Add missing type hints to websocket_api (#50915)
parent
dc65f279a7
commit
42ff687c32
|
@ -1,11 +1,12 @@
|
|||
"""WebSocket based API for Home Assistant."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
from typing import Final, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
from . import commands, connection, const, decorators, http, messages # noqa: F401
|
||||
|
@ -34,11 +35,9 @@ from .messages import ( # noqa: F401
|
|||
result_message,
|
||||
)
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
DOMAIN: Final = const.DOMAIN
|
||||
|
||||
DOMAIN = const.DOMAIN
|
||||
|
||||
DEPENDENCIES = ("http",)
|
||||
DEPENDENCIES: Final[tuple[str]] = ("http",)
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
@ -53,8 +52,8 @@ def async_register_command(
|
|||
# pylint: disable=protected-access
|
||||
if handler is None:
|
||||
handler = cast(const.WebSocketCommandHandler, command_or_handler)
|
||||
command = handler._ws_command # type: ignore
|
||||
schema = handler._ws_schema # type: ignore
|
||||
command = handler._ws_command # type: ignore[attr-defined]
|
||||
schema = handler._ws_schema # type: ignore[attr-defined]
|
||||
else:
|
||||
command = command_or_handler
|
||||
handlers = hass.data.get(DOMAIN)
|
||||
|
@ -63,8 +62,8 @@ def async_register_command(
|
|||
handlers[command] = (handler, schema)
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Initialize the websocket API."""
|
||||
hass.http.register_view(http.WebsocketAPIView)
|
||||
hass.http.register_view(http.WebsocketAPIView())
|
||||
commands.async_register_commands(hass, async_register_command)
|
||||
return True
|
||||
|
|
|
@ -1,22 +1,31 @@
|
|||
"""Handle the auth of a connection."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
|
||||
from aiohttp.web import Request
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant.auth.models import RefreshToken, User
|
||||
from homeassistant.components.http.ban import process_success_login, process_wrong_login
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .connection import ActiveConnection
|
||||
from .error import Disconnect
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
if TYPE_CHECKING:
|
||||
from .http import WebSocketAdapter
|
||||
|
||||
TYPE_AUTH = "auth"
|
||||
TYPE_AUTH_INVALID = "auth_invalid"
|
||||
TYPE_AUTH_OK = "auth_ok"
|
||||
TYPE_AUTH_REQUIRED = "auth_required"
|
||||
|
||||
AUTH_MESSAGE_SCHEMA = vol.Schema(
|
||||
TYPE_AUTH: Final = "auth"
|
||||
TYPE_AUTH_INVALID: Final = "auth_invalid"
|
||||
TYPE_AUTH_OK: Final = "auth_ok"
|
||||
TYPE_AUTH_REQUIRED: Final = "auth_required"
|
||||
|
||||
AUTH_MESSAGE_SCHEMA: Final = vol.Schema(
|
||||
{
|
||||
vol.Required("type"): TYPE_AUTH,
|
||||
vol.Exclusive("api_password", "auth"): str,
|
||||
|
@ -25,17 +34,17 @@ AUTH_MESSAGE_SCHEMA = vol.Schema(
|
|||
)
|
||||
|
||||
|
||||
def auth_ok_message():
|
||||
def auth_ok_message() -> dict[str, str]:
|
||||
"""Return an auth_ok message."""
|
||||
return {"type": TYPE_AUTH_OK, "ha_version": __version__}
|
||||
|
||||
|
||||
def auth_required_message():
|
||||
def auth_required_message() -> dict[str, str]:
|
||||
"""Return an auth_required message."""
|
||||
return {"type": TYPE_AUTH_REQUIRED, "ha_version": __version__}
|
||||
|
||||
|
||||
def auth_invalid_message(message):
|
||||
def auth_invalid_message(message: str) -> dict[str, str]:
|
||||
"""Return an auth_invalid message."""
|
||||
return {"type": TYPE_AUTH_INVALID, "message": message}
|
||||
|
||||
|
@ -43,16 +52,20 @@ def auth_invalid_message(message):
|
|||
class AuthPhase:
|
||||
"""Connection that requires client to authenticate first."""
|
||||
|
||||
def __init__(self, logger, hass, send_message, request):
|
||||
def __init__(
|
||||
self,
|
||||
logger: WebSocketAdapter,
|
||||
hass: HomeAssistant,
|
||||
send_message: Callable[[str | dict[str, Any]], None],
|
||||
request: Request,
|
||||
) -> None:
|
||||
"""Initialize the authentiated connection."""
|
||||
self._hass = hass
|
||||
self._send_message = send_message
|
||||
self._logger = logger
|
||||
self._request = request
|
||||
self._authenticated = False
|
||||
self._connection = None
|
||||
|
||||
async def async_handle(self, msg):
|
||||
async def async_handle(self, msg: dict[str, str]) -> ActiveConnection:
|
||||
"""Handle authentication."""
|
||||
try:
|
||||
msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
"""Commands part of Websocket API."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -8,7 +12,7 @@ from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
|
|||
from homeassistant.bootstrap import SIGNAL_BOOTSTRAP_INTEGRATONS
|
||||
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
|
||||
from homeassistant.const import EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import Context, Event, HomeAssistant, callback
|
||||
from homeassistant.exceptions import (
|
||||
HomeAssistantError,
|
||||
ServiceNotFound,
|
||||
|
@ -17,19 +21,25 @@ from homeassistant.exceptions import (
|
|||
)
|
||||
from homeassistant.helpers import config_validation as cv, entity, template
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.event import TrackTemplate, async_track_template_result
|
||||
from homeassistant.helpers.event import (
|
||||
TrackTemplate,
|
||||
TrackTemplateResult,
|
||||
async_track_template_result,
|
||||
)
|
||||
from homeassistant.helpers.json import ExtendedJSONEncoder
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
|
||||
|
||||
from . import const, decorators, messages
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
from .connection import ActiveConnection
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_commands(hass, async_reg):
|
||||
def async_register_commands(
|
||||
hass: HomeAssistant,
|
||||
async_reg: Callable[[HomeAssistant, const.WebSocketCommandHandler], None],
|
||||
) -> None:
|
||||
"""Register commands."""
|
||||
async_reg(hass, handle_call_service)
|
||||
async_reg(hass, handle_entity_source)
|
||||
|
@ -49,7 +59,7 @@ def async_register_commands(hass, async_reg):
|
|||
async_reg(hass, handle_unsubscribe_events)
|
||||
|
||||
|
||||
def pong_message(iden):
|
||||
def pong_message(iden: int) -> dict[str, Any]:
|
||||
"""Return a pong message."""
|
||||
return {"id": iden, "type": "pong"}
|
||||
|
||||
|
@ -61,7 +71,9 @@ def pong_message(iden):
|
|||
vol.Optional("event_type", default=MATCH_ALL): str,
|
||||
}
|
||||
)
|
||||
def handle_subscribe_events(hass, connection, msg):
|
||||
def handle_subscribe_events(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle subscribe events command."""
|
||||
# Circular dep
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
@ -75,7 +87,7 @@ def handle_subscribe_events(hass, connection, msg):
|
|||
if event_type == EVENT_STATE_CHANGED:
|
||||
|
||||
@callback
|
||||
def forward_events(event):
|
||||
def forward_events(event: Event) -> None:
|
||||
"""Forward state changed events to websocket."""
|
||||
if not connection.user.permissions.check_entity(
|
||||
event.data["entity_id"], POLICY_READ
|
||||
|
@ -87,7 +99,7 @@ def handle_subscribe_events(hass, connection, msg):
|
|||
else:
|
||||
|
||||
@callback
|
||||
def forward_events(event):
|
||||
def forward_events(event: Event) -> None:
|
||||
"""Forward events to websocket."""
|
||||
if event.event_type == EVENT_TIME_CHANGED:
|
||||
return
|
||||
|
@ -107,11 +119,13 @@ def handle_subscribe_events(hass, connection, msg):
|
|||
vol.Required("type"): "subscribe_bootstrap_integrations",
|
||||
}
|
||||
)
|
||||
def handle_subscribe_bootstrap_integrations(hass, connection, msg):
|
||||
def handle_subscribe_bootstrap_integrations(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle subscribe bootstrap integrations command."""
|
||||
|
||||
@callback
|
||||
def forward_bootstrap_integrations(message):
|
||||
def forward_bootstrap_integrations(message: dict[str, Any]) -> None:
|
||||
"""Forward bootstrap integrations to websocket."""
|
||||
connection.send_message(messages.event_message(msg["id"], message))
|
||||
|
||||
|
@ -129,7 +143,9 @@ def handle_subscribe_bootstrap_integrations(hass, connection, msg):
|
|||
vol.Required("subscription"): cv.positive_int,
|
||||
}
|
||||
)
|
||||
def handle_unsubscribe_events(hass, connection, msg):
|
||||
def handle_unsubscribe_events(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle unsubscribe events command."""
|
||||
subscription = msg["subscription"]
|
||||
|
||||
|
@ -154,7 +170,9 @@ def handle_unsubscribe_events(hass, connection, msg):
|
|||
}
|
||||
)
|
||||
@decorators.async_response
|
||||
async def handle_call_service(hass, connection, msg):
|
||||
async def handle_call_service(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle call service command."""
|
||||
blocking = True
|
||||
# We do not support templates.
|
||||
|
@ -206,7 +224,9 @@ async def handle_call_service(hass, connection, msg):
|
|||
|
||||
@callback
|
||||
@decorators.websocket_command({vol.Required("type"): "get_states"})
|
||||
def handle_get_states(hass, connection, msg):
|
||||
def handle_get_states(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle get states command."""
|
||||
if connection.user.permissions.access_all_entities("read"):
|
||||
states = hass.states.async_all()
|
||||
|
@ -223,7 +243,9 @@ def handle_get_states(hass, connection, msg):
|
|||
|
||||
@decorators.websocket_command({vol.Required("type"): "get_services"})
|
||||
@decorators.async_response
|
||||
async def handle_get_services(hass, connection, msg):
|
||||
async def handle_get_services(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle get services command."""
|
||||
descriptions = await async_get_all_descriptions(hass)
|
||||
connection.send_message(messages.result_message(msg["id"], descriptions))
|
||||
|
@ -231,14 +253,18 @@ async def handle_get_services(hass, connection, msg):
|
|||
|
||||
@callback
|
||||
@decorators.websocket_command({vol.Required("type"): "get_config"})
|
||||
def handle_get_config(hass, connection, msg):
|
||||
def handle_get_config(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle get config command."""
|
||||
connection.send_message(messages.result_message(msg["id"], hass.config.as_dict()))
|
||||
|
||||
|
||||
@decorators.websocket_command({vol.Required("type"): "manifest/list"})
|
||||
@decorators.async_response
|
||||
async def handle_manifest_list(hass, connection, msg):
|
||||
async def handle_manifest_list(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle integrations command."""
|
||||
loaded_integrations = async_get_loaded_integrations(hass)
|
||||
integrations = await asyncio.gather(
|
||||
|
@ -253,7 +279,9 @@ async def handle_manifest_list(hass, connection, msg):
|
|||
{vol.Required("type"): "manifest/get", vol.Required("integration"): str}
|
||||
)
|
||||
@decorators.async_response
|
||||
async def handle_manifest_get(hass, connection, msg):
|
||||
async def handle_manifest_get(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle integrations command."""
|
||||
try:
|
||||
integration = await async_get_integration(hass, msg["integration"])
|
||||
|
@ -264,7 +292,9 @@ async def handle_manifest_get(hass, connection, msg):
|
|||
|
||||
@decorators.websocket_command({vol.Required("type"): "integration/setup_info"})
|
||||
@decorators.async_response
|
||||
async def handle_integration_setup_info(hass, connection, msg):
|
||||
async def handle_integration_setup_info(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle integrations command."""
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
|
@ -277,7 +307,9 @@ async def handle_integration_setup_info(hass, connection, msg):
|
|||
|
||||
@callback
|
||||
@decorators.websocket_command({vol.Required("type"): "ping"})
|
||||
def handle_ping(hass, connection, msg):
|
||||
def handle_ping(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle ping command."""
|
||||
connection.send_message(pong_message(msg["id"]))
|
||||
|
||||
|
@ -293,10 +325,12 @@ def handle_ping(hass, connection, msg):
|
|||
}
|
||||
)
|
||||
@decorators.async_response
|
||||
async def handle_render_template(hass, connection, msg):
|
||||
async def handle_render_template(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle render_template command."""
|
||||
template_str = msg["template"]
|
||||
template_obj = template.Template(template_str, hass)
|
||||
template_obj = template.Template(template_str, hass) # type: ignore[no-untyped-call]
|
||||
variables = msg.get("variables")
|
||||
timeout = msg.get("timeout")
|
||||
info = None
|
||||
|
@ -319,7 +353,7 @@ async def handle_render_template(hass, connection, msg):
|
|||
return
|
||||
|
||||
@callback
|
||||
def _template_listener(event, updates):
|
||||
def _template_listener(event: Event, updates: list[TrackTemplateResult]) -> None:
|
||||
nonlocal info
|
||||
track_template_result = updates.pop()
|
||||
result = track_template_result.result
|
||||
|
@ -329,7 +363,7 @@ async def handle_render_template(hass, connection, msg):
|
|||
|
||||
connection.send_message(
|
||||
messages.event_message(
|
||||
msg["id"], {"result": result, "listeners": info.listeners} # type: ignore
|
||||
msg["id"], {"result": result, "listeners": info.listeners} # type: ignore[attr-defined]
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -356,7 +390,9 @@ async def handle_render_template(hass, connection, msg):
|
|||
@decorators.websocket_command(
|
||||
{vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]}
|
||||
)
|
||||
def handle_entity_source(hass, connection, msg):
|
||||
def handle_entity_source(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle entity source command."""
|
||||
raw_sources = entity.entity_sources(hass)
|
||||
entity_perm = connection.user.permissions.check_entity
|
||||
|
@ -404,7 +440,9 @@ def handle_entity_source(hass, connection, msg):
|
|||
)
|
||||
@decorators.require_admin
|
||||
@decorators.async_response
|
||||
async def handle_subscribe_trigger(hass, connection, msg):
|
||||
async def handle_subscribe_trigger(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle subscribe trigger command."""
|
||||
# Circular dep
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
@ -413,7 +451,9 @@ async def handle_subscribe_trigger(hass, connection, msg):
|
|||
trigger_config = await trigger.async_validate_trigger_config(hass, msg["trigger"])
|
||||
|
||||
@callback
|
||||
def forward_triggers(variables, context=None):
|
||||
def forward_triggers(
|
||||
variables: dict[str, Any], context: Context | None = None
|
||||
) -> None:
|
||||
"""Forward events to websocket."""
|
||||
message = messages.event_message(
|
||||
msg["id"], {"variables": variables, "context": context}
|
||||
|
@ -449,7 +489,9 @@ async def handle_subscribe_trigger(hass, connection, msg):
|
|||
)
|
||||
@decorators.require_admin
|
||||
@decorators.async_response
|
||||
async def handle_test_condition(hass, connection, msg):
|
||||
async def handle_test_condition(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle test condition command."""
|
||||
# Circular dep
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
@ -470,7 +512,9 @@ async def handle_test_condition(hass, connection, msg):
|
|||
)
|
||||
@decorators.require_admin
|
||||
@decorators.async_response
|
||||
async def handle_execute_script(hass, connection, msg):
|
||||
async def handle_execute_script(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle execute script command."""
|
||||
# Circular dep
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
|
|
@ -3,48 +3,50 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
from collections.abc import Hashable
|
||||
from typing import Any, Callable
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.auth.models import RefreshToken, User
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
||||
|
||||
from . import const, messages
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
if TYPE_CHECKING:
|
||||
from .http import WebSocketAdapter
|
||||
|
||||
|
||||
class ActiveConnection:
|
||||
"""Handle an active websocket client connection."""
|
||||
|
||||
def __init__(self, logger, hass, send_message, user, refresh_token):
|
||||
def __init__(
|
||||
self,
|
||||
logger: WebSocketAdapter,
|
||||
hass: HomeAssistant,
|
||||
send_message: Callable[[str | dict[str, Any]], None],
|
||||
user: User,
|
||||
refresh_token: RefreshToken,
|
||||
) -> None:
|
||||
"""Initialize an active connection."""
|
||||
self.logger = logger
|
||||
self.hass = hass
|
||||
self.send_message = send_message
|
||||
self.user = user
|
||||
if refresh_token:
|
||||
self.refresh_token_id = refresh_token.id
|
||||
else:
|
||||
self.refresh_token_id = None
|
||||
|
||||
self.refresh_token_id = refresh_token.id
|
||||
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
||||
self.last_id = 0
|
||||
|
||||
def context(self, msg):
|
||||
def context(self, msg: dict[str, Any]) -> Context:
|
||||
"""Return a context."""
|
||||
user = self.user
|
||||
if user is None:
|
||||
return Context()
|
||||
return Context(user_id=user.id)
|
||||
return Context(user_id=self.user.id)
|
||||
|
||||
@callback
|
||||
def send_result(self, msg_id: int, result: Any | None = None) -> None:
|
||||
"""Send a result message."""
|
||||
self.send_message(messages.result_message(msg_id, result))
|
||||
|
||||
async def send_big_result(self, msg_id, result):
|
||||
async def send_big_result(self, msg_id: int, result: Any) -> None:
|
||||
"""Send a result message that would be expensive to JSON serialize."""
|
||||
content = await self.hass.async_add_executor_job(
|
||||
const.JSON_DUMP, messages.result_message(msg_id, result)
|
||||
|
@ -57,7 +59,7 @@ class ActiveConnection:
|
|||
self.send_message(messages.error_message(msg_id, code, message))
|
||||
|
||||
@callback
|
||||
def async_handle(self, msg):
|
||||
def async_handle(self, msg: dict[str, Any]) -> None:
|
||||
"""Handle a single incoming message."""
|
||||
handlers = self.hass.data[const.DOMAIN]
|
||||
|
||||
|
@ -102,13 +104,13 @@ class ActiveConnection:
|
|||
self.last_id = cur_id
|
||||
|
||||
@callback
|
||||
def async_close(self):
|
||||
def async_close(self) -> None:
|
||||
"""Close down connection."""
|
||||
for unsub in self.subscriptions.values():
|
||||
unsub()
|
||||
|
||||
@callback
|
||||
def async_handle_exception(self, msg, err):
|
||||
def async_handle_exception(self, msg: dict[str, Any], err: Exception) -> None:
|
||||
"""Handle an exception while processing a handler."""
|
||||
log_handler = self.logger.error
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
"""Websocket constants."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
from functools import partial
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Final
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
|
@ -12,37 +14,42 @@ if TYPE_CHECKING:
|
|||
from .connection import ActiveConnection
|
||||
|
||||
|
||||
WebSocketCommandHandler = Callable[[HomeAssistant, "ActiveConnection", dict], None]
|
||||
WebSocketCommandHandler = Callable[
|
||||
[HomeAssistant, "ActiveConnection", Dict[str, Any]], None
|
||||
]
|
||||
AsyncWebSocketCommandHandler = Callable[
|
||||
[HomeAssistant, "ActiveConnection", Dict[str, Any]], Awaitable[None]
|
||||
]
|
||||
|
||||
DOMAIN = "websocket_api"
|
||||
URL = "/api/websocket"
|
||||
PENDING_MSG_PEAK = 512
|
||||
PENDING_MSG_PEAK_TIME = 5
|
||||
MAX_PENDING_MSG = 2048
|
||||
DOMAIN: Final = "websocket_api"
|
||||
URL: Final = "/api/websocket"
|
||||
PENDING_MSG_PEAK: Final = 512
|
||||
PENDING_MSG_PEAK_TIME: Final = 5
|
||||
MAX_PENDING_MSG: Final = 2048
|
||||
|
||||
ERR_ID_REUSE = "id_reuse"
|
||||
ERR_INVALID_FORMAT = "invalid_format"
|
||||
ERR_NOT_FOUND = "not_found"
|
||||
ERR_NOT_SUPPORTED = "not_supported"
|
||||
ERR_HOME_ASSISTANT_ERROR = "home_assistant_error"
|
||||
ERR_UNKNOWN_COMMAND = "unknown_command"
|
||||
ERR_UNKNOWN_ERROR = "unknown_error"
|
||||
ERR_UNAUTHORIZED = "unauthorized"
|
||||
ERR_TIMEOUT = "timeout"
|
||||
ERR_TEMPLATE_ERROR = "template_error"
|
||||
ERR_ID_REUSE: Final = "id_reuse"
|
||||
ERR_INVALID_FORMAT: Final = "invalid_format"
|
||||
ERR_NOT_FOUND: Final = "not_found"
|
||||
ERR_NOT_SUPPORTED: Final = "not_supported"
|
||||
ERR_HOME_ASSISTANT_ERROR: Final = "home_assistant_error"
|
||||
ERR_UNKNOWN_COMMAND: Final = "unknown_command"
|
||||
ERR_UNKNOWN_ERROR: Final = "unknown_error"
|
||||
ERR_UNAUTHORIZED: Final = "unauthorized"
|
||||
ERR_TIMEOUT: Final = "timeout"
|
||||
ERR_TEMPLATE_ERROR: Final = "template_error"
|
||||
|
||||
TYPE_RESULT = "result"
|
||||
TYPE_RESULT: Final = "result"
|
||||
|
||||
# Define the possible errors that occur when connections are cancelled.
|
||||
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
|
||||
# that futures.CancelledErrors can also occur in some situations.
|
||||
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)
|
||||
CANCELLATION_ERRORS: Final = (asyncio.CancelledError, futures.CancelledError)
|
||||
|
||||
# Event types
|
||||
SIGNAL_WEBSOCKET_CONNECTED = "websocket_connected"
|
||||
SIGNAL_WEBSOCKET_DISCONNECTED = "websocket_disconnected"
|
||||
SIGNAL_WEBSOCKET_CONNECTED: Final = "websocket_connected"
|
||||
SIGNAL_WEBSOCKET_DISCONNECTED: Final = "websocket_disconnected"
|
||||
|
||||
# Data used to store the current connection list
|
||||
DATA_CONNECTIONS = f"{DOMAIN}.connections"
|
||||
DATA_CONNECTIONS: Final = f"{DOMAIN}.connections"
|
||||
|
||||
JSON_DUMP = partial(json.dumps, cls=JSONEncoder, allow_nan=False)
|
||||
JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, allow_nan=False)
|
||||
|
|
|
@ -2,9 +2,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
from typing import Any, Callable
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import Unauthorized
|
||||
|
@ -12,10 +13,13 @@ from homeassistant.exceptions import Unauthorized
|
|||
from . import const, messages
|
||||
from .connection import ActiveConnection
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
|
||||
|
||||
async def _handle_async_response(func, hass, connection, msg):
|
||||
async def _handle_async_response(
|
||||
func: const.AsyncWebSocketCommandHandler,
|
||||
hass: HomeAssistant,
|
||||
connection: ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Create a response and handle exception."""
|
||||
try:
|
||||
await func(hass, connection, msg)
|
||||
|
@ -24,13 +28,15 @@ async def _handle_async_response(func, hass, connection, msg):
|
|||
|
||||
|
||||
def async_response(
|
||||
func: Callable[[HomeAssistant, ActiveConnection, dict], Awaitable[None]]
|
||||
func: const.AsyncWebSocketCommandHandler,
|
||||
) -> const.WebSocketCommandHandler:
|
||||
"""Decorate an async function to handle WebSocket API messages."""
|
||||
|
||||
@callback
|
||||
@wraps(func)
|
||||
def schedule_handler(hass, connection, msg):
|
||||
def schedule_handler(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Schedule the handler."""
|
||||
# As the webserver is now started before the start
|
||||
# event we do not want to block for websocket responders
|
||||
|
@ -43,7 +49,9 @@ def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommand
|
|||
"""Websocket decorator to require user to be an admin."""
|
||||
|
||||
@wraps(func)
|
||||
def with_admin(hass, connection, msg):
|
||||
def with_admin(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Check admin and call function."""
|
||||
user = connection.user
|
||||
|
||||
|
@ -56,34 +64,32 @@ def require_admin(func: const.WebSocketCommandHandler) -> const.WebSocketCommand
|
|||
|
||||
|
||||
def ws_require_user(
|
||||
only_owner=False,
|
||||
only_system_user=False,
|
||||
allow_system_user=True,
|
||||
only_active_user=True,
|
||||
only_inactive_user=False,
|
||||
):
|
||||
only_owner: bool = False,
|
||||
only_system_user: bool = False,
|
||||
allow_system_user: bool = True,
|
||||
only_active_user: bool = True,
|
||||
only_inactive_user: bool = False,
|
||||
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
||||
"""Decorate function validating login user exist in current WS connection.
|
||||
|
||||
Will write out error message if not authenticated.
|
||||
"""
|
||||
|
||||
def validator(func):
|
||||
def validator(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
|
||||
"""Decorate func."""
|
||||
|
||||
@wraps(func)
|
||||
def check_current_user(hass, connection, msg):
|
||||
def check_current_user(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Check current user."""
|
||||
|
||||
def output_error(message_id, message):
|
||||
def output_error(message_id: str, message: str) -> None:
|
||||
"""Output error message."""
|
||||
connection.send_message(
|
||||
messages.error_message(msg["id"], message_id, message)
|
||||
)
|
||||
|
||||
if connection.user is None:
|
||||
output_error("no_user", "Not authenticated as a user")
|
||||
return
|
||||
|
||||
if only_owner and not connection.user.is_owner:
|
||||
output_error("only_owner", "Only allowed as owner")
|
||||
return
|
||||
|
@ -112,16 +118,16 @@ def ws_require_user(
|
|||
|
||||
|
||||
def websocket_command(
|
||||
schema: dict,
|
||||
schema: dict[vol.Marker, Any],
|
||||
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
||||
"""Tag a function as a websocket command."""
|
||||
command = schema["type"]
|
||||
|
||||
def decorate(func):
|
||||
def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
|
||||
"""Decorate ws command function."""
|
||||
# pylint: disable=protected-access
|
||||
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema)
|
||||
func._ws_command = command
|
||||
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
|
||||
func._ws_command = command # type: ignore[attr-defined]
|
||||
return func
|
||||
|
||||
return decorate
|
||||
|
|
|
@ -2,15 +2,18 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
import datetime as dt
|
||||
import logging
|
||||
from typing import Any, Final
|
||||
|
||||
from aiohttp import WSMsgType, web
|
||||
import async_timeout
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
|
||||
from .auth import AuthPhase, auth_required_message
|
||||
|
@ -27,16 +30,15 @@ from .const import (
|
|||
from .error import Disconnect
|
||||
from .messages import message_to_json
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
_WS_LOGGER = logging.getLogger(f"{__name__}.connection")
|
||||
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
|
||||
|
||||
|
||||
class WebsocketAPIView(HomeAssistantView):
|
||||
"""View to serve a websockets endpoint."""
|
||||
|
||||
name = "websocketapi"
|
||||
url = URL
|
||||
requires_auth = False
|
||||
name: str = "websocketapi"
|
||||
url: str = URL
|
||||
requires_auth: bool = False
|
||||
|
||||
async def get(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle an incoming websocket connection."""
|
||||
|
@ -46,7 +48,7 @@ class WebsocketAPIView(HomeAssistantView):
|
|||
class WebSocketAdapter(logging.LoggerAdapter):
|
||||
"""Add connection id to websocket messages."""
|
||||
|
||||
def process(self, msg, kwargs):
|
||||
def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
|
||||
"""Add connid to websocket log messages."""
|
||||
return f'[{self.extra["connid"]}] {msg}', kwargs
|
||||
|
||||
|
@ -54,20 +56,21 @@ class WebSocketAdapter(logging.LoggerAdapter):
|
|||
class WebSocketHandler:
|
||||
"""Handle an active websocket client connection."""
|
||||
|
||||
def __init__(self, hass, request):
|
||||
def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
|
||||
"""Initialize an active connection."""
|
||||
self.hass = hass
|
||||
self.request = request
|
||||
self.wsock: web.WebSocketResponse | None = None
|
||||
self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
|
||||
self._handle_task = None
|
||||
self._writer_task = None
|
||||
self._handle_task: asyncio.Task | None = None
|
||||
self._writer_task: asyncio.Task | None = None
|
||||
self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
|
||||
self._peak_checker_unsub = None
|
||||
self._peak_checker_unsub: Callable[[], None] | None = None
|
||||
|
||||
async def _writer(self):
|
||||
async def _writer(self) -> None:
|
||||
"""Write outgoing messages."""
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
assert self.wsock is not None
|
||||
with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
|
||||
while not self.wsock.closed:
|
||||
message = await self._to_write.get()
|
||||
|
@ -78,12 +81,12 @@ class WebSocketHandler:
|
|||
await self.wsock.send_str(message)
|
||||
|
||||
# Clean up the peaker checker when we shut down the writer
|
||||
if self._peak_checker_unsub:
|
||||
if self._peak_checker_unsub is not None:
|
||||
self._peak_checker_unsub()
|
||||
self._peak_checker_unsub = None
|
||||
|
||||
@callback
|
||||
def _send_message(self, message):
|
||||
def _send_message(self, message: str | dict[str, Any]) -> None:
|
||||
"""Send a message to the client.
|
||||
|
||||
Closes connection if the client is not reading the messages.
|
||||
|
@ -114,7 +117,7 @@ class WebSocketHandler:
|
|||
)
|
||||
|
||||
@callback
|
||||
def _check_write_peak(self, _):
|
||||
def _check_write_peak(self, _utc_time: dt.datetime) -> None:
|
||||
"""Check that we are no longer above the write peak."""
|
||||
self._peak_checker_unsub = None
|
||||
|
||||
|
@ -129,10 +132,12 @@ class WebSocketHandler:
|
|||
self._cancel()
|
||||
|
||||
@callback
|
||||
def _cancel(self):
|
||||
def _cancel(self) -> None:
|
||||
"""Cancel the connection."""
|
||||
self._handle_task.cancel()
|
||||
self._writer_task.cancel()
|
||||
if self._handle_task is not None:
|
||||
self._handle_task.cancel()
|
||||
if self._writer_task is not None:
|
||||
self._writer_task.cancel()
|
||||
|
||||
async def async_handle(self) -> web.WebSocketResponse:
|
||||
"""Handle a websocket response."""
|
||||
|
@ -143,7 +148,7 @@ class WebSocketHandler:
|
|||
self._handle_task = asyncio.current_task()
|
||||
|
||||
@callback
|
||||
def handle_hass_stop(event):
|
||||
def handle_hass_stop(event: Event) -> None:
|
||||
"""Cancel this connection."""
|
||||
self._cancel()
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Final
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -17,28 +17,27 @@ from homeassistant.util.yaml.loader import JSON_TYPE
|
|||
|
||||
from . import const
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
# mypy: allow-untyped-defs
|
||||
_LOGGER: Final = logging.getLogger(__name__)
|
||||
|
||||
# Minimal requirements of a message
|
||||
MINIMAL_MESSAGE_SCHEMA = vol.Schema(
|
||||
MINIMAL_MESSAGE_SCHEMA: Final = vol.Schema(
|
||||
{vol.Required("id"): cv.positive_int, vol.Required("type"): cv.string},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
# Base schema to extend by message handlers
|
||||
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({vol.Required("id"): cv.positive_int})
|
||||
BASE_COMMAND_MESSAGE_SCHEMA: Final = vol.Schema({vol.Required("id"): cv.positive_int})
|
||||
|
||||
IDEN_TEMPLATE = "__IDEN__"
|
||||
IDEN_JSON_TEMPLATE = '"__IDEN__"'
|
||||
IDEN_TEMPLATE: Final = "__IDEN__"
|
||||
IDEN_JSON_TEMPLATE: Final = '"__IDEN__"'
|
||||
|
||||
|
||||
def result_message(iden: int, result: Any = None) -> dict:
|
||||
def result_message(iden: int, result: Any = None) -> dict[str, Any]:
|
||||
"""Return a success result message."""
|
||||
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
|
||||
|
||||
|
||||
def error_message(iden: int, code: str, message: str) -> dict:
|
||||
def error_message(iden: int | None, code: str, message: str) -> dict[str, Any]:
|
||||
"""Return an error result message."""
|
||||
return {
|
||||
"id": iden,
|
||||
|
@ -48,7 +47,7 @@ def error_message(iden: int, code: str, message: str) -> dict:
|
|||
}
|
||||
|
||||
|
||||
def event_message(iden: JSON_TYPE, event: Any) -> dict:
|
||||
def event_message(iden: JSON_TYPE, event: Any) -> dict[str, Any]:
|
||||
"""Return an event message."""
|
||||
return {"id": iden, "type": "event", "event": event}
|
||||
|
||||
|
@ -75,7 +74,7 @@ def _cached_event_message(event: Event) -> str:
|
|||
return message_to_json(event_message(IDEN_TEMPLATE, event))
|
||||
|
||||
|
||||
def message_to_json(message: Any) -> str:
|
||||
def message_to_json(message: dict[str, Any]) -> str:
|
||||
"""Serialize a websocket message to json."""
|
||||
try:
|
||||
return const.JSON_DUMP(message)
|
||||
|
|
|
@ -2,6 +2,10 @@
|
|||
|
||||
Separate file to avoid circular imports.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Final
|
||||
|
||||
from homeassistant.components.frontend import EVENT_PANELS_UPDATED
|
||||
from homeassistant.components.lovelace.const import EVENT_LOVELACE_UPDATED
|
||||
from homeassistant.components.persistent_notification import (
|
||||
|
@ -22,7 +26,7 @@ from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
|
|||
|
||||
# These are events that do not contain any sensitive data
|
||||
# Except for state_changed, which is handled accordingly.
|
||||
SUBSCRIBE_ALLOWLIST = {
|
||||
SUBSCRIBE_ALLOWLIST: Final[set[str]] = {
|
||||
EVENT_AREA_REGISTRY_UPDATED,
|
||||
EVENT_COMPONENT_LOADED,
|
||||
EVENT_CORE_CONFIG_UPDATE,
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
"""Entity to track connections to websocket API."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.sensor import SensorEntity
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import (
|
||||
DATA_CONNECTIONS,
|
||||
|
@ -9,10 +14,13 @@ from .const import (
|
|||
SIGNAL_WEBSOCKET_DISCONNECTED,
|
||||
)
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
|
||||
async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
|
||||
async def async_setup_platform(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
discovery_info: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Set up the API streams platform."""
|
||||
entity = APICount()
|
||||
|
||||
|
@ -22,11 +30,11 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
|
|||
class APICount(SensorEntity):
|
||||
"""Entity to represent how many people are connected to the stream API."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the API count."""
|
||||
self.count = 0
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Added to hass."""
|
||||
self.async_on_remove(
|
||||
self.hass.helpers.dispatcher.async_dispatcher_connect(
|
||||
|
@ -40,21 +48,21 @@ class APICount(SensorEntity):
|
|||
)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
"""Return name of entity."""
|
||||
return "Connected clients"
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
def state(self) -> int:
|
||||
"""Return current API count."""
|
||||
return self.count
|
||||
|
||||
@property
|
||||
def unit_of_measurement(self):
|
||||
def unit_of_measurement(self) -> str:
|
||||
"""Return the unit of measurement."""
|
||||
return "clients"
|
||||
|
||||
@callback
|
||||
def _update_count(self):
|
||||
def _update_count(self) -> None:
|
||||
self.count = self.hass.data.get(DATA_CONNECTIONS, 0)
|
||||
self.async_write_ha_state()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Test WebSocket Connection class."""
|
||||
import asyncio
|
||||
import logging
|
||||
from unittest.mock import Mock
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -8,6 +9,8 @@ from homeassistant import exceptions
|
|||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api import const
|
||||
|
||||
from tests.common import MockUser
|
||||
|
||||
|
||||
async def test_send_big_result(hass, websocket_client):
|
||||
"""Test sending big results over the WS."""
|
||||
|
@ -31,8 +34,10 @@ async def test_send_big_result(hass, websocket_client):
|
|||
async def test_exception_handling():
|
||||
"""Test handling of exceptions."""
|
||||
send_messages = []
|
||||
user = MockUser()
|
||||
refresh_token = Mock()
|
||||
conn = websocket_api.ActiveConnection(
|
||||
logging.getLogger(__name__), None, send_messages.append, None, None
|
||||
logging.getLogger(__name__), None, send_messages.append, user, refresh_token
|
||||
)
|
||||
|
||||
for (exc, code, err) in (
|
||||
|
|
Loading…
Reference in New Issue