diff --git a/homeassistant/components/websocket_api/__init__.py b/homeassistant/components/websocket_api/__init__.py index e7b10e18889..52158d3f1ad 100644 --- a/homeassistant/components/websocket_api/__init__.py +++ b/homeassistant/components/websocket_api/__init__.py @@ -1,11 +1,12 @@ """WebSocket based API for Home Assistant.""" from __future__ import annotations -from typing import cast +from typing import Final, cast import voluptuous as vol from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.typing import ConfigType from homeassistant.loader import bind_hass from . import commands, connection, const, decorators, http, messages # noqa: F401 @@ -34,11 +35,9 @@ from .messages import ( # noqa: F401 result_message, ) -# mypy: allow-untyped-calls, allow-untyped-defs +DOMAIN: Final = const.DOMAIN -DOMAIN = const.DOMAIN - -DEPENDENCIES = ("http",) +DEPENDENCIES: Final[tuple[str]] = ("http",) @bind_hass @@ -53,8 +52,8 @@ def async_register_command( # pylint: disable=protected-access if handler is None: handler = cast(const.WebSocketCommandHandler, command_or_handler) - command = handler._ws_command # type: ignore - schema = handler._ws_schema # type: ignore + command = handler._ws_command # type: ignore[attr-defined] + schema = handler._ws_schema # type: ignore[attr-defined] else: command = command_or_handler handlers = hass.data.get(DOMAIN) @@ -63,8 +62,8 @@ def async_register_command( handlers[command] = (handler, schema) -async def async_setup(hass, config): +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Initialize the websocket API.""" - hass.http.register_view(http.WebsocketAPIView) + hass.http.register_view(http.WebsocketAPIView()) commands.async_register_commands(hass, async_register_command) return True diff --git a/homeassistant/components/websocket_api/auth.py b/homeassistant/components/websocket_api/auth.py index 3c795902900..130ffe82840 100644 --- a/homeassistant/components/websocket_api/auth.py +++ b/homeassistant/components/websocket_api/auth.py @@ -1,22 +1,31 @@ """Handle the auth of a connection.""" +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Final + +from aiohttp.web import Request import voluptuous as vol from voluptuous.humanize import humanize_error from homeassistant.auth.models import RefreshToken, User from homeassistant.components.http.ban import process_success_login, process_wrong_login from homeassistant.const import __version__ +from homeassistant.core import HomeAssistant from .connection import ActiveConnection from .error import Disconnect -# mypy: allow-untyped-calls, allow-untyped-defs +if TYPE_CHECKING: + from .http import WebSocketAdapter -TYPE_AUTH = "auth" -TYPE_AUTH_INVALID = "auth_invalid" -TYPE_AUTH_OK = "auth_ok" -TYPE_AUTH_REQUIRED = "auth_required" -AUTH_MESSAGE_SCHEMA = vol.Schema( +TYPE_AUTH: Final = "auth" +TYPE_AUTH_INVALID: Final = "auth_invalid" +TYPE_AUTH_OK: Final = "auth_ok" +TYPE_AUTH_REQUIRED: Final = "auth_required" + +AUTH_MESSAGE_SCHEMA: Final = vol.Schema( { vol.Required("type"): TYPE_AUTH, vol.Exclusive("api_password", "auth"): str, @@ -25,17 +34,17 @@ AUTH_MESSAGE_SCHEMA = vol.Schema( ) -def auth_ok_message(): +def auth_ok_message() -> dict[str, str]: """Return an auth_ok message.""" return {"type": TYPE_AUTH_OK, "ha_version": __version__} -def auth_required_message(): +def auth_required_message() -> dict[str, str]: """Return an auth_required message.""" return {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__} -def auth_invalid_message(message): +def auth_invalid_message(message: str) -> dict[str, str]: """Return an auth_invalid message.""" return {"type": TYPE_AUTH_INVALID, "message": message} @@ -43,16 +52,20 @@ def auth_invalid_message(message): class AuthPhase: """Connection that requires client to authenticate first.""" - def __init__(self, logger, hass, send_message, request): + def __init__( + self, + logger: WebSocketAdapter, + hass: HomeAssistant, + send_message: Callable[[str | dict[str, Any]], None], + request: Request, + ) -> None: """Initialize the authentiated connection.""" self._hass = hass self._send_message = send_message self._logger = logger self._request = request - self._authenticated = False - self._connection = None - async def async_handle(self, msg): + async def async_handle(self, msg: dict[str, str]) -> ActiveConnection: """Handle authentication.""" try: msg = AUTH_MESSAGE_SCHEMA(msg) diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 53ff6d1da26..179fbcd1a30 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -1,6 +1,10 @@ """Commands part of Websocket API.""" +from __future__ import annotations + import asyncio +from collections.abc import Callable import json +from typing import Any import voluptuous as vol @@ -8,7 +12,7 @@ from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ from homeassistant.bootstrap import SIGNAL_BOOTSTRAP_INTEGRATONS from homeassistant.components.websocket_api.const import ERR_NOT_FOUND from homeassistant.const import EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL -from homeassistant.core import callback +from homeassistant.core import Context, Event, HomeAssistant, callback from homeassistant.exceptions import ( HomeAssistantError, ServiceNotFound, @@ -17,19 +21,25 @@ from homeassistant.exceptions import ( ) from homeassistant.helpers import config_validation as cv, entity, template from homeassistant.helpers.dispatcher import async_dispatcher_connect -from homeassistant.helpers.event import TrackTemplate, async_track_template_result +from homeassistant.helpers.event import ( + TrackTemplate, + TrackTemplateResult, + async_track_template_result, +) from homeassistant.helpers.json import ExtendedJSONEncoder from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.loader import IntegrationNotFound, async_get_integration from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations from . import const, decorators, messages - -# mypy: allow-untyped-calls, allow-untyped-defs +from .connection import ActiveConnection @callback -def async_register_commands(hass, async_reg): +def async_register_commands( + hass: HomeAssistant, + async_reg: Callable[[HomeAssistant, const.WebSocketCommandHandler], None], +) -> None: """Register commands.""" async_reg(hass, handle_call_service) async_reg(hass, handle_entity_source) @@ -49,7 +59,7 @@ def async_register_commands(hass, async_reg): async_reg(hass, handle_unsubscribe_events) -def pong_message(iden): +def pong_message(iden: int) -> dict[str, Any]: """Return a pong message.""" return {"id": iden, "type": "pong"} @@ -61,7 +71,9 @@ def pong_message(iden): vol.Optional("event_type", default=MATCH_ALL): str, } ) -def handle_subscribe_events(hass, connection, msg): +def handle_subscribe_events( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle subscribe events command.""" # Circular dep # pylint: disable=import-outside-toplevel @@ -75,7 +87,7 @@ def handle_subscribe_events(hass, connection, msg): if event_type == EVENT_STATE_CHANGED: @callback - def forward_events(event): + def forward_events(event: Event) -> None: """Forward state changed events to websocket.""" if not connection.user.permissions.check_entity( event.data["entity_id"], POLICY_READ @@ -87,7 +99,7 @@ def handle_subscribe_events(hass, connection, msg): else: @callback - def forward_events(event): + def forward_events(event: Event) -> None: """Forward events to websocket.""" if event.event_type == EVENT_TIME_CHANGED: return @@ -107,11 +119,13 @@ def handle_subscribe_events(hass, connection, msg): vol.Required("type"): "subscribe_bootstrap_integrations", } ) -def handle_subscribe_bootstrap_integrations(hass, connection, msg): +def handle_subscribe_bootstrap_integrations( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle subscribe bootstrap integrations command.""" @callback - def forward_bootstrap_integrations(message): + def forward_bootstrap_integrations(message: dict[str, Any]) -> None: """Forward bootstrap integrations to websocket.""" connection.send_message(messages.event_message(msg["id"], message)) @@ -129,7 +143,9 @@ def handle_subscribe_bootstrap_integrations(hass, connection, msg): vol.Required("subscription"): cv.positive_int, } ) -def handle_unsubscribe_events(hass, connection, msg): +def handle_unsubscribe_events( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle unsubscribe events command.""" subscription = msg["subscription"] @@ -154,7 +170,9 @@ def handle_unsubscribe_events(hass, connection, msg): } ) @decorators.async_response -async def handle_call_service(hass, connection, msg): +async def handle_call_service( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle call service command.""" blocking = True # We do not support templates. @@ -206,7 +224,9 @@ async def handle_call_service(hass, connection, msg): @callback @decorators.websocket_command({vol.Required("type"): "get_states"}) -def handle_get_states(hass, connection, msg): +def handle_get_states( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle get states command.""" if connection.user.permissions.access_all_entities("read"): states = hass.states.async_all() @@ -223,7 +243,9 @@ def handle_get_states(hass, connection, msg): @decorators.websocket_command({vol.Required("type"): "get_services"}) @decorators.async_response -async def handle_get_services(hass, connection, msg): +async def handle_get_services( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle get services command.""" descriptions = await async_get_all_descriptions(hass) connection.send_message(messages.result_message(msg["id"], descriptions)) @@ -231,14 +253,18 @@ async def handle_get_services(hass, connection, msg): @callback @decorators.websocket_command({vol.Required("type"): "get_config"}) -def handle_get_config(hass, connection, msg): +def handle_get_config( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle get config command.""" connection.send_message(messages.result_message(msg["id"], hass.config.as_dict())) @decorators.websocket_command({vol.Required("type"): "manifest/list"}) @decorators.async_response -async def handle_manifest_list(hass, connection, msg): +async def handle_manifest_list( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle integrations command.""" loaded_integrations = async_get_loaded_integrations(hass) integrations = await asyncio.gather( @@ -253,7 +279,9 @@ async def handle_manifest_list(hass, connection, msg): {vol.Required("type"): "manifest/get", vol.Required("integration"): str} ) @decorators.async_response -async def handle_manifest_get(hass, connection, msg): +async def handle_manifest_get( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle integrations command.""" try: integration = await async_get_integration(hass, msg["integration"]) @@ -264,7 +292,9 @@ async def handle_manifest_get(hass, connection, msg): @decorators.websocket_command({vol.Required("type"): "integration/setup_info"}) @decorators.async_response -async def handle_integration_setup_info(hass, connection, msg): +async def handle_integration_setup_info( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle integrations command.""" connection.send_result( msg["id"], @@ -277,7 +307,9 @@ async def handle_integration_setup_info(hass, connection, msg): @callback @decorators.websocket_command({vol.Required("type"): "ping"}) -def handle_ping(hass, connection, msg): +def handle_ping( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle ping command.""" connection.send_message(pong_message(msg["id"])) @@ -293,10 +325,12 @@ def handle_ping(hass, connection, msg): } ) @decorators.async_response -async def handle_render_template(hass, connection, msg): +async def handle_render_template( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle render_template command.""" template_str = msg["template"] - template_obj = template.Template(template_str, hass) + template_obj = template.Template(template_str, hass) # type: ignore[no-untyped-call] variables = msg.get("variables") timeout = msg.get("timeout") info = None @@ -319,7 +353,7 @@ async def handle_render_template(hass, connection, msg): return @callback - def _template_listener(event, updates): + def _template_listener(event: Event, updates: list[TrackTemplateResult]) -> None: nonlocal info track_template_result = updates.pop() result = track_template_result.result @@ -329,7 +363,7 @@ async def handle_render_template(hass, connection, msg): connection.send_message( messages.event_message( - msg["id"], {"result": result, "listeners": info.listeners} # type: ignore + msg["id"], {"result": result, "listeners": info.listeners} # type: ignore[attr-defined] ) ) @@ -356,7 +390,9 @@ async def handle_render_template(hass, connection, msg): @decorators.websocket_command( {vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]} ) -def handle_entity_source(hass, connection, msg): +def handle_entity_source( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle entity source command.""" raw_sources = entity.entity_sources(hass) entity_perm = connection.user.permissions.check_entity @@ -404,7 +440,9 @@ def handle_entity_source(hass, connection, msg): ) @decorators.require_admin @decorators.async_response -async def handle_subscribe_trigger(hass, connection, msg): +async def handle_subscribe_trigger( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle subscribe trigger command.""" # Circular dep # pylint: disable=import-outside-toplevel @@ -413,7 +451,9 @@ async def handle_subscribe_trigger(hass, connection, msg): trigger_config = await trigger.async_validate_trigger_config(hass, msg["trigger"]) @callback - def forward_triggers(variables, context=None): + def forward_triggers( + variables: dict[str, Any], context: Context | None = None + ) -> None: """Forward events to websocket.""" message = messages.event_message( msg["id"], {"variables": variables, "context": context} @@ -449,7 +489,9 @@ async def handle_subscribe_trigger(hass, connection, msg): ) @decorators.require_admin @decorators.async_response -async def handle_test_condition(hass, connection, msg): +async def handle_test_condition( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle test condition command.""" # Circular dep # pylint: disable=import-outside-toplevel @@ -470,7 +512,9 @@ async def handle_test_condition(hass, connection, msg): ) @decorators.require_admin @decorators.async_response -async def handle_execute_script(hass, connection, msg): +async def handle_execute_script( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] +) -> None: """Handle execute script command.""" # Circular dep # pylint: disable=import-outside-toplevel diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 4e0ba257d59..62c21ef5894 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -3,48 +3,50 @@ from __future__ import annotations import asyncio from collections.abc import Hashable -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable import voluptuous as vol -from homeassistant.core import Context, callback +from homeassistant.auth.models import RefreshToken, User +from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, Unauthorized from . import const, messages -# mypy: allow-untyped-calls, allow-untyped-defs +if TYPE_CHECKING: + from .http import WebSocketAdapter class ActiveConnection: """Handle an active websocket client connection.""" - def __init__(self, logger, hass, send_message, user, refresh_token): + def __init__( + self, + logger: WebSocketAdapter, + hass: HomeAssistant, + send_message: Callable[[str | dict[str, Any]], None], + user: User, + refresh_token: RefreshToken, + ) -> None: """Initialize an active connection.""" self.logger = logger self.hass = hass self.send_message = send_message self.user = user - if refresh_token: - self.refresh_token_id = refresh_token.id - else: - self.refresh_token_id = None - + self.refresh_token_id = refresh_token.id self.subscriptions: dict[Hashable, Callable[[], Any]] = {} self.last_id = 0 - def context(self, msg): + def context(self, msg: dict[str, Any]) -> Context: """Return a context.""" - user = self.user - if user is None: - return Context() - return Context(user_id=user.id) + return Context(user_id=self.user.id) @callback def send_result(self, msg_id: int, result: Any | None = None) -> None: """Send a result message.""" self.send_message(messages.result_message(msg_id, result)) - async def send_big_result(self, msg_id, result): + async def send_big_result(self, msg_id: int, result: Any) -> None: """Send a result message that would be expensive to JSON serialize.""" content = await self.hass.async_add_executor_job( const.JSON_DUMP, messages.result_message(msg_id, result) @@ -57,7 +59,7 @@ class ActiveConnection: self.send_message(messages.error_message(msg_id, code, message)) @callback - def async_handle(self, msg): + def async_handle(self, msg: dict[str, Any]) -> None: """Handle a single incoming message.""" handlers = self.hass.data[const.DOMAIN] @@ -102,13 +104,13 @@ class ActiveConnection: self.last_id = cur_id @callback - def async_close(self): + def async_close(self) -> None: """Close down connection.""" for unsub in self.subscriptions.values(): unsub() @callback - def async_handle_exception(self, msg, err): + def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None: """Handle an exception while processing a handler.""" log_handler = self.logger.error diff --git a/homeassistant/components/websocket_api/const.py b/homeassistant/components/websocket_api/const.py index 7c3f18f856c..69716b97076 100644 --- a/homeassistant/components/websocket_api/const.py +++ b/homeassistant/components/websocket_api/const.py @@ -1,9 +1,11 @@ """Websocket constants.""" +from __future__ import annotations + import asyncio from concurrent import futures from functools import partial import json -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Final from homeassistant.core import HomeAssistant from homeassistant.helpers.json import JSONEncoder @@ -12,37 +14,42 @@ if TYPE_CHECKING: from .connection import ActiveConnection -WebSocketCommandHandler = Callable[[HomeAssistant, "ActiveConnection", dict], None] +WebSocketCommandHandler = Callable[ + [HomeAssistant, "ActiveConnection", Dict[str, Any]], None +] +AsyncWebSocketCommandHandler = Callable[ + [HomeAssistant, "ActiveConnection", Dict[str, Any]], Awaitable[None] +] -DOMAIN = "websocket_api" -URL = "/api/websocket" -PENDING_MSG_PEAK = 512 -PENDING_MSG_PEAK_TIME = 5 -MAX_PENDING_MSG = 2048 +DOMAIN: Final = "websocket_api" +URL: Final = "/api/websocket" +PENDING_MSG_PEAK: Final = 512 +PENDING_MSG_PEAK_TIME: Final = 5 +MAX_PENDING_MSG: Final = 2048 -ERR_ID_REUSE = "id_reuse" -ERR_INVALID_FORMAT = "invalid_format" -ERR_NOT_FOUND = "not_found" -ERR_NOT_SUPPORTED = "not_supported" -ERR_HOME_ASSISTANT_ERROR = "home_assistant_error" -ERR_UNKNOWN_COMMAND = "unknown_command" -ERR_UNKNOWN_ERROR = "unknown_error" -ERR_UNAUTHORIZED = "unauthorized" -ERR_TIMEOUT = "timeout" -ERR_TEMPLATE_ERROR = "template_error" +ERR_ID_REUSE: Final = "id_reuse" +ERR_INVALID_FORMAT: Final = "invalid_format" +ERR_NOT_FOUND: Final = "not_found" +ERR_NOT_SUPPORTED: Final = "not_supported" +ERR_HOME_ASSISTANT_ERROR: Final = "home_assistant_error" +ERR_UNKNOWN_COMMAND: Final = "unknown_command" +ERR_UNKNOWN_ERROR: Final = "unknown_error" +ERR_UNAUTHORIZED: Final = "unauthorized" +ERR_TIMEOUT: Final = "timeout" +ERR_TEMPLATE_ERROR: Final = "template_error" -TYPE_RESULT = "result" +TYPE_RESULT: Final = "result" # Define the possible errors that occur when connections are cancelled. # Originally, this was just asyncio.CancelledError, but issue #9546 showed # that futures.CancelledErrors can also occur in some situations. -CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError) +CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError) # Event types -SIGNAL_WEBSOCKET_CONNECTED = "websocket_connected" -SIGNAL_WEBSOCKET_DISCONNECTED = "websocket_disconnected" +SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected" +SIGNAL_WEBSOCKET_DISCONNECTED: Final = "websocket_disconnected" # Data used to store the current connection list -DATA_CONNECTIONS = f"{DOMAIN}.connections" +DATA_CONNECTIONS: Final = f"{DOMAIN}.connections" -JSON_DUMP = partial(json.dumps, cls=JSONEncoder, allow_nan=False) +JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, allow_nan=False) diff --git a/homeassistant/components/websocket_api/decorators.py b/homeassistant/components/websocket_api/decorators.py index cbb0e8563c5..af762cf2d46 100644 --- a/homeassistant/components/websocket_api/decorators.py +++ b/homeassistant/components/websocket_api/decorators.py @@ -2,9 +2,10 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable from functools import wraps -from typing import Callable +from typing import Any, Callable + +import voluptuous as vol from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import Unauthorized @@ -12,10 +13,13 @@ from homeassistant.exceptions import Unauthorized from . import const, messages from .connection import ActiveConnection -# mypy: allow-untyped-calls, allow-untyped-defs - -async def _handle_async_response(func, hass, connection, msg): +async def _handle_async_response( + func: const.AsyncWebSocketCommandHandler, + hass: HomeAssistant, + connection: ActiveConnection, + msg: dict[str, Any], +) -> None: """Create a response and handle exception.""" try: await func(hass, connection, msg) @@ -24,13 +28,15 @@ async def _handle_async_response(func, hass, connection, msg): def async_response( - func: Callable[[HomeAssistant, ActiveConnection, dict], Awaitable[None]] + func: const.AsyncWebSocketCommandHandler, ) -> const.WebSocketCommandHandler: """Decorate an async function to handle WebSocket API messages.""" @callback @wraps(func) - def schedule_handler(hass, connection, msg): + def schedule_handler( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] + ) -> None: """Schedule the handler.""" # As the webserver is now started before the start # event we do not want to block for websocket responders @@ -43,7 +49,9 @@ def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommand """Websocket decorator to require user to be an admin.""" @wraps(func) - def with_admin(hass, connection, msg): + def with_admin( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] + ) -> None: """Check admin and call function.""" user = connection.user @@ -56,34 +64,32 @@ def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommand def ws_require_user( - only_owner=False, - only_system_user=False, - allow_system_user=True, - only_active_user=True, - only_inactive_user=False, -): + only_owner: bool = False, + only_system_user: bool = False, + allow_system_user: bool = True, + only_active_user: bool = True, + only_inactive_user: bool = False, +) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]: """Decorate function validating login user exist in current WS connection. Will write out error message if not authenticated. """ - def validator(func): + def validator(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler: """Decorate func.""" @wraps(func) - def check_current_user(hass, connection, msg): + def check_current_user( + hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any] + ) -> None: """Check current user.""" - def output_error(message_id, message): + def output_error(message_id: str, message: str) -> None: """Output error message.""" connection.send_message( messages.error_message(msg["id"], message_id, message) ) - if connection.user is None: - output_error("no_user", "Not authenticated as a user") - return - if only_owner and not connection.user.is_owner: output_error("only_owner", "Only allowed as owner") return @@ -112,16 +118,16 @@ def ws_require_user( def websocket_command( - schema: dict, + schema: dict[vol.Marker, Any], ) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]: """Tag a function as a websocket command.""" command = schema["type"] - def decorate(func): + def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler: """Decorate ws command function.""" # pylint: disable=protected-access - func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) - func._ws_command = command + func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined] + func._ws_command = command # type: ignore[attr-defined] return func return decorate diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index a84db598fdc..a80ff111f0d 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -2,15 +2,18 @@ from __future__ import annotations import asyncio +from collections.abc import Callable from contextlib import suppress +import datetime as dt import logging +from typing import Any, Final from aiohttp import WSMsgType, web import async_timeout from homeassistant.components.http import HomeAssistantView from homeassistant.const import EVENT_HOMEASSISTANT_STOP -from homeassistant.core import callback +from homeassistant.core import Event, HomeAssistant, callback from homeassistant.helpers.event import async_call_later from .auth import AuthPhase, auth_required_message @@ -27,16 +30,15 @@ from .const import ( from .error import Disconnect from .messages import message_to_json -# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs -_WS_LOGGER = logging.getLogger(f"{__name__}.connection") +_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection") class WebsocketAPIView(HomeAssistantView): """View to serve a websockets endpoint.""" - name = "websocketapi" - url = URL - requires_auth = False + name: str = "websocketapi" + url: str = URL + requires_auth: bool = False async def get(self, request: web.Request) -> web.WebSocketResponse: """Handle an incoming websocket connection.""" @@ -46,7 +48,7 @@ class WebsocketAPIView(HomeAssistantView): class WebSocketAdapter(logging.LoggerAdapter): """Add connection id to websocket messages.""" - def process(self, msg, kwargs): + def process(self, msg: str, kwargs: Any) -> tuple[str, Any]: """Add connid to websocket log messages.""" return f'[{self.extra["connid"]}] {msg}', kwargs @@ -54,20 +56,21 @@ class WebSocketAdapter(logging.LoggerAdapter): class WebSocketHandler: """Handle an active websocket client connection.""" - def __init__(self, hass, request): + def __init__(self, hass: HomeAssistant, request: web.Request) -> None: """Initialize an active connection.""" self.hass = hass self.request = request self.wsock: web.WebSocketResponse | None = None self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG) - self._handle_task = None - self._writer_task = None + self._handle_task: asyncio.Task | None = None + self._writer_task: asyncio.Task | None = None self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)}) - self._peak_checker_unsub = None + self._peak_checker_unsub: Callable[[], None] | None = None - async def _writer(self): + async def _writer(self) -> None: """Write outgoing messages.""" # Exceptions if Socket disconnected or cancelled by connection handler + assert self.wsock is not None with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS): while not self.wsock.closed: message = await self._to_write.get() @@ -78,12 +81,12 @@ class WebSocketHandler: await self.wsock.send_str(message) # Clean up the peaker checker when we shut down the writer - if self._peak_checker_unsub: + if self._peak_checker_unsub is not None: self._peak_checker_unsub() self._peak_checker_unsub = None @callback - def _send_message(self, message): + def _send_message(self, message: str | dict[str, Any]) -> None: """Send a message to the client. Closes connection if the client is not reading the messages. @@ -114,7 +117,7 @@ class WebSocketHandler: ) @callback - def _check_write_peak(self, _): + def _check_write_peak(self, _utc_time: dt.datetime) -> None: """Check that we are no longer above the write peak.""" self._peak_checker_unsub = None @@ -129,10 +132,12 @@ class WebSocketHandler: self._cancel() @callback - def _cancel(self): + def _cancel(self) -> None: """Cancel the connection.""" - self._handle_task.cancel() - self._writer_task.cancel() + if self._handle_task is not None: + self._handle_task.cancel() + if self._writer_task is not None: + self._writer_task.cancel() async def async_handle(self) -> web.WebSocketResponse: """Handle a websocket response.""" @@ -143,7 +148,7 @@ class WebSocketHandler: self._handle_task = asyncio.current_task() @callback - def handle_hass_stop(event): + def handle_hass_stop(event: Event) -> None: """Cancel this connection.""" self._cancel() diff --git a/homeassistant/components/websocket_api/messages.py b/homeassistant/components/websocket_api/messages.py index 736a7ad59f0..8cdda3f8fa3 100644 --- a/homeassistant/components/websocket_api/messages.py +++ b/homeassistant/components/websocket_api/messages.py @@ -3,7 +3,7 @@ from __future__ import annotations from functools import lru_cache import logging -from typing import Any +from typing import Any, Final import voluptuous as vol @@ -17,28 +17,27 @@ from homeassistant.util.yaml.loader import JSON_TYPE from . import const -_LOGGER = logging.getLogger(__name__) -# mypy: allow-untyped-defs +_LOGGER: Final = logging.getLogger(__name__) # Minimal requirements of a message -MINIMAL_MESSAGE_SCHEMA = vol.Schema( +MINIMAL_MESSAGE_SCHEMA: Final = vol.Schema( {vol.Required("id"): cv.positive_int, vol.Required("type"): cv.string}, extra=vol.ALLOW_EXTRA, ) # Base schema to extend by message handlers -BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({vol.Required("id"): cv.positive_int}) +BASE_COMMAND_MESSAGE_SCHEMA: Final = vol.Schema({vol.Required("id"): cv.positive_int}) -IDEN_TEMPLATE = "__IDEN__" -IDEN_JSON_TEMPLATE = '"__IDEN__"' +IDEN_TEMPLATE: Final = "__IDEN__" +IDEN_JSON_TEMPLATE: Final = '"__IDEN__"' -def result_message(iden: int, result: Any = None) -> dict: +def result_message(iden: int, result: Any = None) -> dict[str, Any]: """Return a success result message.""" return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result} -def error_message(iden: int, code: str, message: str) -> dict: +def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]: """Return an error result message.""" return { "id": iden, @@ -48,7 +47,7 @@ def error_message(iden: int, code: str, message: str) -> dict: } -def event_message(iden: JSON_TYPE, event: Any) -> dict: +def event_message(iden: JSON_TYPE, event: Any) -> dict[str, Any]: """Return an event message.""" return {"id": iden, "type": "event", "event": event} @@ -75,7 +74,7 @@ def _cached_event_message(event: Event) -> str: return message_to_json(event_message(IDEN_TEMPLATE, event)) -def message_to_json(message: Any) -> str: +def message_to_json(message: dict[str, Any]) -> str: """Serialize a websocket message to json.""" try: return const.JSON_DUMP(message) diff --git a/homeassistant/components/websocket_api/permissions.py b/homeassistant/components/websocket_api/permissions.py index 010a18f972c..5dade8eeb2a 100644 --- a/homeassistant/components/websocket_api/permissions.py +++ b/homeassistant/components/websocket_api/permissions.py @@ -2,6 +2,10 @@ Separate file to avoid circular imports. """ +from __future__ import annotations + +from typing import Final + from homeassistant.components.frontend import EVENT_PANELS_UPDATED from homeassistant.components.lovelace.const import EVENT_LOVELACE_UPDATED from homeassistant.components.persistent_notification import ( @@ -22,7 +26,7 @@ from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED # These are events that do not contain any sensitive data # Except for state_changed, which is handled accordingly. -SUBSCRIBE_ALLOWLIST = { +SUBSCRIBE_ALLOWLIST: Final[set[str]] = { EVENT_AREA_REGISTRY_UPDATED, EVENT_COMPONENT_LOADED, EVENT_CORE_CONFIG_UPDATE, diff --git a/homeassistant/components/websocket_api/sensor.py b/homeassistant/components/websocket_api/sensor.py index dfcdc57842e..60d42e97604 100644 --- a/homeassistant/components/websocket_api/sensor.py +++ b/homeassistant/components/websocket_api/sensor.py @@ -1,7 +1,12 @@ """Entity to track connections to websocket API.""" +from __future__ import annotations + +from typing import Any from homeassistant.components.sensor import SensorEntity -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.typing import ConfigType from .const import ( DATA_CONNECTIONS, @@ -9,10 +14,13 @@ from .const import ( SIGNAL_WEBSOCKET_DISCONNECTED, ) -# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs - -async def async_setup_platform(hass, config, async_add_entities, discovery_info=None): +async def async_setup_platform( + hass: HomeAssistant, + config: ConfigType, + async_add_entities: AddEntitiesCallback, + discovery_info: dict[str, Any] | None = None, +) -> None: """Set up the API streams platform.""" entity = APICount() @@ -22,11 +30,11 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= class APICount(SensorEntity): """Entity to represent how many people are connected to the stream API.""" - def __init__(self): + def __init__(self) -> None: """Initialize the API count.""" self.count = 0 - async def async_added_to_hass(self): + async def async_added_to_hass(self) -> None: """Added to hass.""" self.async_on_remove( self.hass.helpers.dispatcher.async_dispatcher_connect( @@ -40,21 +48,21 @@ class APICount(SensorEntity): ) @property - def name(self): + def name(self) -> str: """Return name of entity.""" return "Connected clients" @property - def state(self): + def state(self) -> int: """Return current API count.""" return self.count @property - def unit_of_measurement(self): + def unit_of_measurement(self) -> str: """Return the unit of measurement.""" return "clients" @callback - def _update_count(self): + def _update_count(self) -> None: self.count = self.hass.data.get(DATA_CONNECTIONS, 0) self.async_write_ha_state() diff --git a/tests/components/websocket_api/test_connection.py b/tests/components/websocket_api/test_connection.py index 55126ff1333..1d6bf5f2f6b 100644 --- a/tests/components/websocket_api/test_connection.py +++ b/tests/components/websocket_api/test_connection.py @@ -1,6 +1,7 @@ """Test WebSocket Connection class.""" import asyncio import logging +from unittest.mock import Mock import voluptuous as vol @@ -8,6 +9,8 @@ from homeassistant import exceptions from homeassistant.components import websocket_api from homeassistant.components.websocket_api import const +from tests.common import MockUser + async def test_send_big_result(hass, websocket_client): """Test sending big results over the WS.""" @@ -31,8 +34,10 @@ async def test_send_big_result(hass, websocket_client): async def test_exception_handling(): """Test handling of exceptions.""" send_messages = [] + user = MockUser() + refresh_token = Mock() conn = websocket_api.ActiveConnection( - logging.getLogger(__name__), None, send_messages.append, None, None + logging.getLogger(__name__), None, send_messages.append, user, refresh_token ) for (exc, code, err) in (