Add missing type hints to websocket_api (#50915)

pull/50940/head
Ruslan Sayfutdinov 2021-05-21 17:39:18 +01:00 committed by GitHub
parent dc65f279a7
commit 42ff687c32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 251 additions and 159 deletions

View File

@ -1,11 +1,12 @@
"""WebSocket based API for Home Assistant."""
from __future__ import annotations
from typing import cast
from typing import Final, cast
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from . import commands, connection, const, decorators, http, messages # noqa: F401
@ -34,11 +35,9 @@ from .messages import ( # noqa: F401
result_message,
)
# mypy: allow-untyped-calls, allow-untyped-defs
DOMAIN: Final = const.DOMAIN
DOMAIN = const.DOMAIN
DEPENDENCIES = ("http",)
DEPENDENCIES: Final[tuple[str]] = ("http",)
@bind_hass
@ -53,8 +52,8 @@ def async_register_command(
# pylint: disable=protected-access
if handler is None:
handler = cast(const.WebSocketCommandHandler, command_or_handler)
command = handler._ws_command # type: ignore
schema = handler._ws_schema # type: ignore
command = handler._ws_command # type: ignore[attr-defined]
schema = handler._ws_schema # type: ignore[attr-defined]
else:
command = command_or_handler
handlers = hass.data.get(DOMAIN)
@ -63,8 +62,8 @@ def async_register_command(
handlers[command] = (handler, schema)
async def async_setup(hass, config):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Initialize the websocket API."""
hass.http.register_view(http.WebsocketAPIView)
hass.http.register_view(http.WebsocketAPIView())
commands.async_register_commands(hass, async_register_command)
return True

View File

@ -1,22 +1,31 @@
"""Handle the auth of a connection."""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Final
from aiohttp.web import Request
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.auth.models import RefreshToken, User
from homeassistant.components.http.ban import process_success_login, process_wrong_login
from homeassistant.const import __version__
from homeassistant.core import HomeAssistant
from .connection import ActiveConnection
from .error import Disconnect
# mypy: allow-untyped-calls, allow-untyped-defs
if TYPE_CHECKING:
from .http import WebSocketAdapter
TYPE_AUTH = "auth"
TYPE_AUTH_INVALID = "auth_invalid"
TYPE_AUTH_OK = "auth_ok"
TYPE_AUTH_REQUIRED = "auth_required"
AUTH_MESSAGE_SCHEMA = vol.Schema(
TYPE_AUTH: Final = "auth"
TYPE_AUTH_INVALID: Final = "auth_invalid"
TYPE_AUTH_OK: Final = "auth_ok"
TYPE_AUTH_REQUIRED: Final = "auth_required"
AUTH_MESSAGE_SCHEMA: Final = vol.Schema(
{
vol.Required("type"): TYPE_AUTH,
vol.Exclusive("api_password", "auth"): str,
@ -25,17 +34,17 @@ AUTH_MESSAGE_SCHEMA = vol.Schema(
)
def auth_ok_message():
def auth_ok_message() -> dict[str, str]:
"""Return an auth_ok message."""
return {"type": TYPE_AUTH_OK, "ha_version": __version__}
def auth_required_message():
def auth_required_message() -> dict[str, str]:
"""Return an auth_required message."""
return {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
def auth_invalid_message(message):
def auth_invalid_message(message: str) -> dict[str, str]:
"""Return an auth_invalid message."""
return {"type": TYPE_AUTH_INVALID, "message": message}
@ -43,16 +52,20 @@ def auth_invalid_message(message):
class AuthPhase:
"""Connection that requires client to authenticate first."""
def __init__(self, logger, hass, send_message, request):
def __init__(
self,
logger: WebSocketAdapter,
hass: HomeAssistant,
send_message: Callable[[str | dict[str, Any]], None],
request: Request,
) -> None:
"""Initialize the authentiated connection."""
self._hass = hass
self._send_message = send_message
self._logger = logger
self._request = request
self._authenticated = False
self._connection = None
async def async_handle(self, msg):
async def async_handle(self, msg: dict[str, str]) -> ActiveConnection:
"""Handle authentication."""
try:
msg = AUTH_MESSAGE_SCHEMA(msg)

View File

@ -1,6 +1,10 @@
"""Commands part of Websocket API."""
from __future__ import annotations
import asyncio
from collections.abc import Callable
import json
from typing import Any
import voluptuous as vol
@ -8,7 +12,7 @@ from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
from homeassistant.bootstrap import SIGNAL_BOOTSTRAP_INTEGRATONS
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
from homeassistant.const import EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL
from homeassistant.core import callback
from homeassistant.core import Context, Event, HomeAssistant, callback
from homeassistant.exceptions import (
HomeAssistantError,
ServiceNotFound,
@ -17,19 +21,25 @@ from homeassistant.exceptions import (
)
from homeassistant.helpers import config_validation as cv, entity, template
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.event import TrackTemplate, async_track_template_result
from homeassistant.helpers.event import (
TrackTemplate,
TrackTemplateResult,
async_track_template_result,
)
from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
from . import const, decorators, messages
# mypy: allow-untyped-calls, allow-untyped-defs
from .connection import ActiveConnection
@callback
def async_register_commands(hass, async_reg):
def async_register_commands(
hass: HomeAssistant,
async_reg: Callable[[HomeAssistant, const.WebSocketCommandHandler], None],
) -> None:
"""Register commands."""
async_reg(hass, handle_call_service)
async_reg(hass, handle_entity_source)
@ -49,7 +59,7 @@ def async_register_commands(hass, async_reg):
async_reg(hass, handle_unsubscribe_events)
def pong_message(iden):
def pong_message(iden: int) -> dict[str, Any]:
"""Return a pong message."""
return {"id": iden, "type": "pong"}
@ -61,7 +71,9 @@ def pong_message(iden):
vol.Optional("event_type", default=MATCH_ALL): str,
}
)
def handle_subscribe_events(hass, connection, msg):
def handle_subscribe_events(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle subscribe events command."""
# Circular dep
# pylint: disable=import-outside-toplevel
@ -75,7 +87,7 @@ def handle_subscribe_events(hass, connection, msg):
if event_type == EVENT_STATE_CHANGED:
@callback
def forward_events(event):
def forward_events(event: Event) -> None:
"""Forward state changed events to websocket."""
if not connection.user.permissions.check_entity(
event.data["entity_id"], POLICY_READ
@ -87,7 +99,7 @@ def handle_subscribe_events(hass, connection, msg):
else:
@callback
def forward_events(event):
def forward_events(event: Event) -> None:
"""Forward events to websocket."""
if event.event_type == EVENT_TIME_CHANGED:
return
@ -107,11 +119,13 @@ def handle_subscribe_events(hass, connection, msg):
vol.Required("type"): "subscribe_bootstrap_integrations",
}
)
def handle_subscribe_bootstrap_integrations(hass, connection, msg):
def handle_subscribe_bootstrap_integrations(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle subscribe bootstrap integrations command."""
@callback
def forward_bootstrap_integrations(message):
def forward_bootstrap_integrations(message: dict[str, Any]) -> None:
"""Forward bootstrap integrations to websocket."""
connection.send_message(messages.event_message(msg["id"], message))
@ -129,7 +143,9 @@ def handle_subscribe_bootstrap_integrations(hass, connection, msg):
vol.Required("subscription"): cv.positive_int,
}
)
def handle_unsubscribe_events(hass, connection, msg):
def handle_unsubscribe_events(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle unsubscribe events command."""
subscription = msg["subscription"]
@ -154,7 +170,9 @@ def handle_unsubscribe_events(hass, connection, msg):
}
)
@decorators.async_response
async def handle_call_service(hass, connection, msg):
async def handle_call_service(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle call service command."""
blocking = True
# We do not support templates.
@ -206,7 +224,9 @@ async def handle_call_service(hass, connection, msg):
@callback
@decorators.websocket_command({vol.Required("type"): "get_states"})
def handle_get_states(hass, connection, msg):
def handle_get_states(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get states command."""
if connection.user.permissions.access_all_entities("read"):
states = hass.states.async_all()
@ -223,7 +243,9 @@ def handle_get_states(hass, connection, msg):
@decorators.websocket_command({vol.Required("type"): "get_services"})
@decorators.async_response
async def handle_get_services(hass, connection, msg):
async def handle_get_services(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get services command."""
descriptions = await async_get_all_descriptions(hass)
connection.send_message(messages.result_message(msg["id"], descriptions))
@ -231,14 +253,18 @@ async def handle_get_services(hass, connection, msg):
@callback
@decorators.websocket_command({vol.Required("type"): "get_config"})
def handle_get_config(hass, connection, msg):
def handle_get_config(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle get config command."""
connection.send_message(messages.result_message(msg["id"], hass.config.as_dict()))
@decorators.websocket_command({vol.Required("type"): "manifest/list"})
@decorators.async_response
async def handle_manifest_list(hass, connection, msg):
async def handle_manifest_list(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle integrations command."""
loaded_integrations = async_get_loaded_integrations(hass)
integrations = await asyncio.gather(
@ -253,7 +279,9 @@ async def handle_manifest_list(hass, connection, msg):
{vol.Required("type"): "manifest/get", vol.Required("integration"): str}
)
@decorators.async_response
async def handle_manifest_get(hass, connection, msg):
async def handle_manifest_get(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle integrations command."""
try:
integration = await async_get_integration(hass, msg["integration"])
@ -264,7 +292,9 @@ async def handle_manifest_get(hass, connection, msg):
@decorators.websocket_command({vol.Required("type"): "integration/setup_info"})
@decorators.async_response
async def handle_integration_setup_info(hass, connection, msg):
async def handle_integration_setup_info(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle integrations command."""
connection.send_result(
msg["id"],
@ -277,7 +307,9 @@ async def handle_integration_setup_info(hass, connection, msg):
@callback
@decorators.websocket_command({vol.Required("type"): "ping"})
def handle_ping(hass, connection, msg):
def handle_ping(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle ping command."""
connection.send_message(pong_message(msg["id"]))
@ -293,10 +325,12 @@ def handle_ping(hass, connection, msg):
}
)
@decorators.async_response
async def handle_render_template(hass, connection, msg):
async def handle_render_template(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle render_template command."""
template_str = msg["template"]
template_obj = template.Template(template_str, hass)
template_obj = template.Template(template_str, hass) # type: ignore[no-untyped-call]
variables = msg.get("variables")
timeout = msg.get("timeout")
info = None
@ -319,7 +353,7 @@ async def handle_render_template(hass, connection, msg):
return
@callback
def _template_listener(event, updates):
def _template_listener(event: Event, updates: list[TrackTemplateResult]) -> None:
nonlocal info
track_template_result = updates.pop()
result = track_template_result.result
@ -329,7 +363,7 @@ async def handle_render_template(hass, connection, msg):
connection.send_message(
messages.event_message(
msg["id"], {"result": result, "listeners": info.listeners} # type: ignore
msg["id"], {"result": result, "listeners": info.listeners} # type: ignore[attr-defined]
)
)
@ -356,7 +390,9 @@ async def handle_render_template(hass, connection, msg):
@decorators.websocket_command(
{vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]}
)
def handle_entity_source(hass, connection, msg):
def handle_entity_source(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle entity source command."""
raw_sources = entity.entity_sources(hass)
entity_perm = connection.user.permissions.check_entity
@ -404,7 +440,9 @@ def handle_entity_source(hass, connection, msg):
)
@decorators.require_admin
@decorators.async_response
async def handle_subscribe_trigger(hass, connection, msg):
async def handle_subscribe_trigger(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle subscribe trigger command."""
# Circular dep
# pylint: disable=import-outside-toplevel
@ -413,7 +451,9 @@ async def handle_subscribe_trigger(hass, connection, msg):
trigger_config = await trigger.async_validate_trigger_config(hass, msg["trigger"])
@callback
def forward_triggers(variables, context=None):
def forward_triggers(
variables: dict[str, Any], context: Context | None = None
) -> None:
"""Forward events to websocket."""
message = messages.event_message(
msg["id"], {"variables": variables, "context": context}
@ -449,7 +489,9 @@ async def handle_subscribe_trigger(hass, connection, msg):
)
@decorators.require_admin
@decorators.async_response
async def handle_test_condition(hass, connection, msg):
async def handle_test_condition(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle test condition command."""
# Circular dep
# pylint: disable=import-outside-toplevel
@ -470,7 +512,9 @@ async def handle_test_condition(hass, connection, msg):
)
@decorators.require_admin
@decorators.async_response
async def handle_execute_script(hass, connection, msg):
async def handle_execute_script(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle execute script command."""
# Circular dep
# pylint: disable=import-outside-toplevel

View File

@ -3,48 +3,50 @@ from __future__ import annotations
import asyncio
from collections.abc import Hashable
from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable
import voluptuous as vol
from homeassistant.core import Context, callback
from homeassistant.auth.models import RefreshToken, User
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from . import const, messages
# mypy: allow-untyped-calls, allow-untyped-defs
if TYPE_CHECKING:
from .http import WebSocketAdapter
class ActiveConnection:
"""Handle an active websocket client connection."""
def __init__(self, logger, hass, send_message, user, refresh_token):
def __init__(
self,
logger: WebSocketAdapter,
hass: HomeAssistant,
send_message: Callable[[str | dict[str, Any]], None],
user: User,
refresh_token: RefreshToken,
) -> None:
"""Initialize an active connection."""
self.logger = logger
self.hass = hass
self.send_message = send_message
self.user = user
if refresh_token:
self.refresh_token_id = refresh_token.id
else:
self.refresh_token_id = None
self.refresh_token_id = refresh_token.id
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
self.last_id = 0
def context(self, msg):
def context(self, msg: dict[str, Any]) -> Context:
"""Return a context."""
user = self.user
if user is None:
return Context()
return Context(user_id=user.id)
return Context(user_id=self.user.id)
@callback
def send_result(self, msg_id: int, result: Any | None = None) -> None:
"""Send a result message."""
self.send_message(messages.result_message(msg_id, result))
async def send_big_result(self, msg_id, result):
async def send_big_result(self, msg_id: int, result: Any) -> None:
"""Send a result message that would be expensive to JSON serialize."""
content = await self.hass.async_add_executor_job(
const.JSON_DUMP, messages.result_message(msg_id, result)
@ -57,7 +59,7 @@ class ActiveConnection:
self.send_message(messages.error_message(msg_id, code, message))
@callback
def async_handle(self, msg):
def async_handle(self, msg: dict[str, Any]) -> None:
"""Handle a single incoming message."""
handlers = self.hass.data[const.DOMAIN]
@ -102,13 +104,13 @@ class ActiveConnection:
self.last_id = cur_id
@callback
def async_close(self):
def async_close(self) -> None:
"""Close down connection."""
for unsub in self.subscriptions.values():
unsub()
@callback
def async_handle_exception(self, msg, err):
def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None:
"""Handle an exception while processing a handler."""
log_handler = self.logger.error

View File

@ -1,9 +1,11 @@
"""Websocket constants."""
from __future__ import annotations
import asyncio
from concurrent import futures
from functools import partial
import json
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Final
from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder
@ -12,37 +14,42 @@ if TYPE_CHECKING:
from .connection import ActiveConnection
WebSocketCommandHandler = Callable[[HomeAssistant, "ActiveConnection", dict], None]
WebSocketCommandHandler = Callable[
[HomeAssistant, "ActiveConnection", Dict[str, Any]], None
]
AsyncWebSocketCommandHandler = Callable[
[HomeAssistant, "ActiveConnection", Dict[str, Any]], Awaitable[None]
]
DOMAIN = "websocket_api"
URL = "/api/websocket"
PENDING_MSG_PEAK = 512
PENDING_MSG_PEAK_TIME = 5
MAX_PENDING_MSG = 2048
DOMAIN: Final = "websocket_api"
URL: Final = "/api/websocket"
PENDING_MSG_PEAK: Final = 512
PENDING_MSG_PEAK_TIME: Final = 5
MAX_PENDING_MSG: Final = 2048
ERR_ID_REUSE = "id_reuse"
ERR_INVALID_FORMAT = "invalid_format"
ERR_NOT_FOUND = "not_found"
ERR_NOT_SUPPORTED = "not_supported"
ERR_HOME_ASSISTANT_ERROR = "home_assistant_error"
ERR_UNKNOWN_COMMAND = "unknown_command"
ERR_UNKNOWN_ERROR = "unknown_error"
ERR_UNAUTHORIZED = "unauthorized"
ERR_TIMEOUT = "timeout"
ERR_TEMPLATE_ERROR = "template_error"
ERR_ID_REUSE: Final = "id_reuse"
ERR_INVALID_FORMAT: Final = "invalid_format"
ERR_NOT_FOUND: Final = "not_found"
ERR_NOT_SUPPORTED: Final = "not_supported"
ERR_HOME_ASSISTANT_ERROR: Final = "home_assistant_error"
ERR_UNKNOWN_COMMAND: Final = "unknown_command"
ERR_UNKNOWN_ERROR: Final = "unknown_error"
ERR_UNAUTHORIZED: Final = "unauthorized"
ERR_TIMEOUT: Final = "timeout"
ERR_TEMPLATE_ERROR: Final = "template_error"
TYPE_RESULT = "result"
TYPE_RESULT: Final = "result"
# Define the possible errors that occur when connections are cancelled.
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
# that futures.CancelledErrors can also occur in some situations.
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)
CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError)
# Event types
SIGNAL_WEBSOCKET_CONNECTED = "websocket_connected"
SIGNAL_WEBSOCKET_DISCONNECTED = "websocket_disconnected"
SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected"
SIGNAL_WEBSOCKET_DISCONNECTED: Final = "websocket_disconnected"
# Data used to store the current connection list
DATA_CONNECTIONS = f"{DOMAIN}.connections"
DATA_CONNECTIONS: Final = f"{DOMAIN}.connections"
JSON_DUMP = partial(json.dumps, cls=JSONEncoder, allow_nan=False)
JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, allow_nan=False)

View File

@ -2,9 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable
from functools import wraps
from typing import Callable
from typing import Any, Callable
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import Unauthorized
@ -12,10 +13,13 @@ from homeassistant.exceptions import Unauthorized
from . import const, messages
from .connection import ActiveConnection
# mypy: allow-untyped-calls, allow-untyped-defs
async def _handle_async_response(func, hass, connection, msg):
async def _handle_async_response(
func: const.AsyncWebSocketCommandHandler,
hass: HomeAssistant,
connection: ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Create a response and handle exception."""
try:
await func(hass, connection, msg)
@ -24,13 +28,15 @@ async def _handle_async_response(func, hass, connection, msg):
def async_response(
func: Callable[[HomeAssistant, ActiveConnection, dict], Awaitable[None]]
func: const.AsyncWebSocketCommandHandler,
) -> const.WebSocketCommandHandler:
"""Decorate an async function to handle WebSocket API messages."""
@callback
@wraps(func)
def schedule_handler(hass, connection, msg):
def schedule_handler(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Schedule the handler."""
# As the webserver is now started before the start
# event we do not want to block for websocket responders
@ -43,7 +49,9 @@ def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommand
"""Websocket decorator to require user to be an admin."""
@wraps(func)
def with_admin(hass, connection, msg):
def with_admin(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Check admin and call function."""
user = connection.user
@ -56,34 +64,32 @@ def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommand
def ws_require_user(
only_owner=False,
only_system_user=False,
allow_system_user=True,
only_active_user=True,
only_inactive_user=False,
):
only_owner: bool = False,
only_system_user: bool = False,
allow_system_user: bool = True,
only_active_user: bool = True,
only_inactive_user: bool = False,
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
"""Decorate function validating login user exist in current WS connection.
Will write out error message if not authenticated.
"""
def validator(func):
def validator(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
"""Decorate func."""
@wraps(func)
def check_current_user(hass, connection, msg):
def check_current_user(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Check current user."""
def output_error(message_id, message):
def output_error(message_id: str, message: str) -> None:
"""Output error message."""
connection.send_message(
messages.error_message(msg["id"], message_id, message)
)
if connection.user is None:
output_error("no_user", "Not authenticated as a user")
return
if only_owner and not connection.user.is_owner:
output_error("only_owner", "Only allowed as owner")
return
@ -112,16 +118,16 @@ def ws_require_user(
def websocket_command(
schema: dict,
schema: dict[vol.Marker, Any],
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
"""Tag a function as a websocket command."""
command = schema["type"]
def decorate(func):
def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
"""Decorate ws command function."""
# pylint: disable=protected-access
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema)
func._ws_command = command
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
func._ws_command = command # type: ignore[attr-defined]
return func
return decorate

View File

@ -2,15 +2,18 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from contextlib import suppress
import datetime as dt
import logging
from typing import Any, Final
from aiohttp import WSMsgType, web
import async_timeout
from homeassistant.components.http import HomeAssistantView
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import callback
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.event import async_call_later
from .auth import AuthPhase, auth_required_message
@ -27,16 +30,15 @@ from .const import (
from .error import Disconnect
from .messages import message_to_json
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
_WS_LOGGER = logging.getLogger(f"{__name__}.connection")
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
class WebsocketAPIView(HomeAssistantView):
"""View to serve a websockets endpoint."""
name = "websocketapi"
url = URL
requires_auth = False
name: str = "websocketapi"
url: str = URL
requires_auth: bool = False
async def get(self, request: web.Request) -> web.WebSocketResponse:
"""Handle an incoming websocket connection."""
@ -46,7 +48,7 @@ class WebsocketAPIView(HomeAssistantView):
class WebSocketAdapter(logging.LoggerAdapter):
"""Add connection id to websocket messages."""
def process(self, msg, kwargs):
def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
"""Add connid to websocket log messages."""
return f'[{self.extra["connid"]}] {msg}', kwargs
@ -54,20 +56,21 @@ class WebSocketAdapter(logging.LoggerAdapter):
class WebSocketHandler:
"""Handle an active websocket client connection."""
def __init__(self, hass, request):
def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
"""Initialize an active connection."""
self.hass = hass
self.request = request
self.wsock: web.WebSocketResponse | None = None
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
self._handle_task = None
self._writer_task = None
self._handle_task: asyncio.Task | None = None
self._writer_task: asyncio.Task | None = None
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
self._peak_checker_unsub = None
self._peak_checker_unsub: Callable[[], None] | None = None
async def _writer(self):
async def _writer(self) -> None:
"""Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler
assert self.wsock is not None
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
message = await self._to_write.get()
@ -78,12 +81,12 @@ class WebSocketHandler:
await self.wsock.send_str(message)
# Clean up the peaker checker when we shut down the writer
if self._peak_checker_unsub:
if self._peak_checker_unsub is not None:
self._peak_checker_unsub()
self._peak_checker_unsub = None
@callback
def _send_message(self, message):
def _send_message(self, message: str | dict[str, Any]) -> None:
"""Send a message to the client.
Closes connection if the client is not reading the messages.
@ -114,7 +117,7 @@ class WebSocketHandler:
)
@callback
def _check_write_peak(self, _):
def _check_write_peak(self, _utc_time: dt.datetime) -> None:
"""Check that we are no longer above the write peak."""
self._peak_checker_unsub = None
@ -129,10 +132,12 @@ class WebSocketHandler:
self._cancel()
@callback
def _cancel(self):
def _cancel(self) -> None:
"""Cancel the connection."""
self._handle_task.cancel()
self._writer_task.cancel()
if self._handle_task is not None:
self._handle_task.cancel()
if self._writer_task is not None:
self._writer_task.cancel()
async def async_handle(self) -> web.WebSocketResponse:
"""Handle a websocket response."""
@ -143,7 +148,7 @@ class WebSocketHandler:
self._handle_task = asyncio.current_task()
@callback
def handle_hass_stop(event):
def handle_hass_stop(event: Event) -> None:
"""Cancel this connection."""
self._cancel()

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from functools import lru_cache
import logging
from typing import Any
from typing import Any, Final
import voluptuous as vol
@ -17,28 +17,27 @@ from homeassistant.util.yaml.loader import JSON_TYPE
from . import const
_LOGGER = logging.getLogger(__name__)
# mypy: allow-untyped-defs
_LOGGER: Final = logging.getLogger(__name__)
# Minimal requirements of a message
MINIMAL_MESSAGE_SCHEMA = vol.Schema(
MINIMAL_MESSAGE_SCHEMA: Final = vol.Schema(
{vol.Required("id"): cv.positive_int, vol.Required("type"): cv.string},
extra=vol.ALLOW_EXTRA,
)
# Base schema to extend by message handlers
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({vol.Required("id"): cv.positive_int})
BASE_COMMAND_MESSAGE_SCHEMA: Final = vol.Schema({vol.Required("id"): cv.positive_int})
IDEN_TEMPLATE = "__IDEN__"
IDEN_JSON_TEMPLATE = '"__IDEN__"'
IDEN_TEMPLATE: Final = "__IDEN__"
IDEN_JSON_TEMPLATE: Final = '"__IDEN__"'
def result_message(iden: int, result: Any = None) -> dict:
def result_message(iden: int, result: Any = None) -> dict[str, Any]:
"""Return a success result message."""
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
def error_message(iden: int, code: str, message: str) -> dict:
def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
"""Return an error result message."""
return {
"id": iden,
@ -48,7 +47,7 @@ def error_message(iden: int, code: str, message: str) -> dict:
}
def event_message(iden: JSON_TYPE, event: Any) -> dict:
def event_message(iden: JSON_TYPE, event: Any) -> dict[str, Any]:
"""Return an event message."""
return {"id": iden, "type": "event", "event": event}
@ -75,7 +74,7 @@ def _cached_event_message(event: Event) -> str:
return message_to_json(event_message(IDEN_TEMPLATE, event))
def message_to_json(message: Any) -> str:
def message_to_json(message: dict[str, Any]) -> str:
"""Serialize a websocket message to json."""
try:
return const.JSON_DUMP(message)

View File

@ -2,6 +2,10 @@
Separate file to avoid circular imports.
"""
from __future__ import annotations
from typing import Final
from homeassistant.components.frontend import EVENT_PANELS_UPDATED
from homeassistant.components.lovelace.const import EVENT_LOVELACE_UPDATED
from homeassistant.components.persistent_notification import (
@ -22,7 +26,7 @@ from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
# These are events that do not contain any sensitive data
# Except for state_changed, which is handled accordingly.
SUBSCRIBE_ALLOWLIST = {
SUBSCRIBE_ALLOWLIST: Final[set[str]] = {
EVENT_AREA_REGISTRY_UPDATED,
EVENT_COMPONENT_LOADED,
EVENT_CORE_CONFIG_UPDATE,

View File

@ -1,7 +1,12 @@
"""Entity to track connections to websocket API."""
from __future__ import annotations
from typing import Any
from homeassistant.components.sensor import SensorEntity
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType
from .const import (
DATA_CONNECTIONS,
@ -9,10 +14,13 @@ from .const import (
SIGNAL_WEBSOCKET_DISCONNECTED,
)
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
async def async_setup_platform(
hass: HomeAssistant,
config: ConfigType,
async_add_entities: AddEntitiesCallback,
discovery_info: dict[str, Any] | None = None,
) -> None:
"""Set up the API streams platform."""
entity = APICount()
@ -22,11 +30,11 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
class APICount(SensorEntity):
"""Entity to represent how many people are connected to the stream API."""
def __init__(self):
def __init__(self) -> None:
"""Initialize the API count."""
self.count = 0
async def async_added_to_hass(self):
async def async_added_to_hass(self) -> None:
"""Added to hass."""
self.async_on_remove(
self.hass.helpers.dispatcher.async_dispatcher_connect(
@ -40,21 +48,21 @@ class APICount(SensorEntity):
)
@property
def name(self):
def name(self) -> str:
"""Return name of entity."""
return "Connected clients"
@property
def state(self):
def state(self) -> int:
"""Return current API count."""
return self.count
@property
def unit_of_measurement(self):
def unit_of_measurement(self) -> str:
"""Return the unit of measurement."""
return "clients"
@callback
def _update_count(self):
def _update_count(self) -> None:
self.count = self.hass.data.get(DATA_CONNECTIONS, 0)
self.async_write_ha_state()

View File

@ -1,6 +1,7 @@
"""Test WebSocket Connection class."""
import asyncio
import logging
from unittest.mock import Mock
import voluptuous as vol
@ -8,6 +9,8 @@ from homeassistant import exceptions
from homeassistant.components import websocket_api
from homeassistant.components.websocket_api import const
from tests.common import MockUser
async def test_send_big_result(hass, websocket_client):
"""Test sending big results over the WS."""
@ -31,8 +34,10 @@ async def test_send_big_result(hass, websocket_client):
async def test_exception_handling():
"""Test handling of exceptions."""
send_messages = []
user = MockUser()
refresh_token = Mock()
conn = websocket_api.ActiveConnection(
logging.getLogger(__name__), None, send_messages.append, None, None
logging.getLogger(__name__), None, send_messages.append, user, refresh_token
)
for (exc, code, err) in (