Improve api typing (#108307)

pull/108306/head^2
Marc Mueller 2024-01-18 23:45:15 +01:00 committed by GitHub
parent a670ac25fd
commit 7c6fe31505
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 52 additions and 33 deletions

View File

@ -4,6 +4,7 @@ from asyncio import shield, timeout
from functools import lru_cache
from http import HTTPStatus
import logging
from typing import Any
from aiohttp import web
from aiohttp.web_exceptions import HTTPBadRequest
@ -30,7 +31,7 @@ from homeassistant.const import (
URL_API_TEMPLATE,
)
import homeassistant.core as ha
from homeassistant.core import HomeAssistant
from homeassistant.core import Event, HomeAssistant
from homeassistant.exceptions import (
InvalidEntityFormatError,
InvalidStateError,
@ -92,7 +93,7 @@ class APIStatusView(HomeAssistantView):
name = "api:status"
@ha.callback
def get(self, request):
def get(self, request: web.Request) -> web.Response:
"""Retrieve if API is running."""
return self.json_message("API running.")
@ -124,14 +125,15 @@ class APIEventStream(HomeAssistantView):
@require_admin
async def get(self, request):
"""Provide a streaming interface for the event bus."""
hass = request.app["hass"]
hass: HomeAssistant = request.app["hass"]
stop_obj = object()
to_write = asyncio.Queue()
to_write: asyncio.Queue[object | str] = asyncio.Queue()
if restrict := request.query.get("restrict"):
restrict = restrict.split(",") + [EVENT_HOMEASSISTANT_STOP]
restrict: list[str] | None = None
if restrict_str := request.query.get("restrict"):
restrict = restrict_str.split(",") + [EVENT_HOMEASSISTANT_STOP]
async def forward_events(event):
async def forward_events(event: Event) -> None:
"""Forward events to the open request."""
if restrict and event.event_type not in restrict:
return
@ -188,9 +190,10 @@ class APIConfigView(HomeAssistantView):
name = "api:config"
@ha.callback
def get(self, request):
def get(self, request: web.Request) -> web.Response:
"""Get current configuration."""
return self.json(request.app["hass"].config.as_dict())
hass: HomeAssistant = request.app["hass"]
return self.json(hass.config.as_dict())
class APIStatesView(HomeAssistantView):
@ -243,9 +246,10 @@ class APIEntityStateView(HomeAssistantView):
)
return self.json_message("Entity not found.", HTTPStatus.NOT_FOUND)
async def post(self, request, entity_id):
async def post(self, request: web.Request, entity_id: str) -> web.Response:
"""Update state of entity."""
if not request["hass_user"].is_admin:
user: User = request["hass_user"]
if not user.is_admin:
raise Unauthorized(entity_id=entity_id)
hass: HomeAssistant = request.app["hass"]
try:
@ -275,18 +279,20 @@ class APIEntityStateView(HomeAssistantView):
# Read the state back for our response
status_code = HTTPStatus.CREATED if is_new_state else HTTPStatus.OK
resp = self.json(hass.states.get(entity_id).as_dict(), status_code)
assert (state := hass.states.get(entity_id))
resp = self.json(state.as_dict(), status_code)
resp.headers.add("Location", f"/api/states/{entity_id}")
return resp
@ha.callback
def delete(self, request, entity_id):
def delete(self, request: web.Request, entity_id: str) -> web.Response:
"""Remove entity."""
if not request["hass_user"].is_admin:
raise Unauthorized(entity_id=entity_id)
if request.app["hass"].states.async_remove(entity_id):
hass: HomeAssistant = request.app["hass"]
if hass.states.async_remove(entity_id):
return self.json_message("Entity removed.")
return self.json_message("Entity not found.", HTTPStatus.NOT_FOUND)
@ -298,9 +304,10 @@ class APIEventListenersView(HomeAssistantView):
name = "api:event-listeners"
@ha.callback
def get(self, request):
def get(self, request: web.Request) -> web.Response:
"""Get event listeners."""
return self.json(async_events_json(request.app["hass"]))
hass: HomeAssistant = request.app["hass"]
return self.json(async_events_json(hass))
class APIEventView(HomeAssistantView):
@ -310,11 +317,11 @@ class APIEventView(HomeAssistantView):
name = "api:event"
@require_admin
async def post(self, request, event_type):
async def post(self, request: web.Request, event_type: str) -> web.Response:
"""Fire events."""
body = await request.text()
try:
event_data = json_loads(body) if body else None
event_data: Any = json_loads(body) if body else None
except ValueError:
return self.json_message(
"Event data should be valid JSON.", HTTPStatus.BAD_REQUEST
@ -327,14 +334,15 @@ class APIEventView(HomeAssistantView):
# Special case handling for event STATE_CHANGED
# We will try to convert state dicts back to State objects
if event_type == ha.EVENT_STATE_CHANGED and event_data:
if event_type == EVENT_STATE_CHANGED and event_data:
for key in ("old_state", "new_state"):
state = ha.State.from_dict(event_data.get(key))
state = ha.State.from_dict(event_data[key])
if state:
event_data[key] = state
request.app["hass"].bus.async_fire(
hass: HomeAssistant = request.app["hass"]
hass.bus.async_fire(
event_type, event_data, ha.EventOrigin.remote, self.context(request)
)
@ -347,9 +355,10 @@ class APIServicesView(HomeAssistantView):
url = URL_API_SERVICES
name = "api:services"
async def get(self, request):
async def get(self, request: web.Request) -> web.Response:
"""Get registered services."""
services = await async_services_json(request.app["hass"])
hass: HomeAssistant = request.app["hass"]
services = await async_services_json(hass)
return self.json(services)
@ -359,12 +368,14 @@ class APIDomainServicesView(HomeAssistantView):
url = "/api/services/{domain}/{service}"
name = "api:domain-services"
async def post(self, request, domain, service):
async def post(
self, request: web.Request, domain: str, service: str
) -> web.Response:
"""Call a service.
Returns a list of changed states.
"""
hass: ha.HomeAssistant = request.app["hass"]
hass: HomeAssistant = request.app["hass"]
body = await request.text()
try:
data = json_loads(body) if body else None
@ -384,14 +395,20 @@ class APIDomainServicesView(HomeAssistantView):
changed_states.append(state.json_fragment)
cancel_listen = hass.bus.async_listen(
EVENT_STATE_CHANGED, _async_save_changed_entities, run_immediately=True
EVENT_STATE_CHANGED,
_async_save_changed_entities, # type: ignore[arg-type]
run_immediately=True,
)
try:
# shield the service call from cancellation on connection drop
await shield(
hass.services.async_call(
domain, service, data, blocking=True, context=context
domain,
service,
data, # type: ignore[arg-type]
blocking=True,
context=context,
)
)
except (vol.Invalid, ServiceNotFound) as ex:
@ -409,9 +426,10 @@ class APIComponentsView(HomeAssistantView):
name = "api:components"
@ha.callback
def get(self, request):
def get(self, request: web.Request) -> web.Response:
"""Get current loaded components."""
return self.json(request.app["hass"].config.components)
hass: HomeAssistant = request.app["hass"]
return self.json(hass.config.components)
@lru_cache
@ -427,7 +445,7 @@ class APITemplateView(HomeAssistantView):
name = "api:template"
@require_admin
async def post(self, request):
async def post(self, request: web.Request) -> web.Response:
"""Render a template."""
try:
data = await request.json()
@ -448,17 +466,18 @@ class APIErrorLog(HomeAssistantView):
@require_admin
async def get(self, request):
"""Retrieve API error log."""
return web.FileResponse(request.app["hass"].data[DATA_LOGGING])
hass: HomeAssistant = request.app["hass"]
return web.FileResponse(hass.data[DATA_LOGGING])
async def async_services_json(hass):
async def async_services_json(hass: HomeAssistant) -> list[dict[str, Any]]:
"""Generate services data to JSONify."""
descriptions = await async_get_all_descriptions(hass)
return [{"domain": key, "services": value} for key, value in descriptions.items()]
@ha.callback
def async_events_json(hass):
def async_events_json(hass: HomeAssistant) -> list[dict[str, Any]]:
"""Generate event data to JSONify."""
return [
{"event": key, "listener_count": value}