189 lines
6.1 KiB
Python
189 lines
6.1 KiB
Python
"""Helper to track the current http request."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import Awaitable, Callable
|
|
from contextvars import ContextVar
|
|
from http import HTTPStatus
|
|
import logging
|
|
from typing import Any, Final
|
|
|
|
from aiohttp import web
|
|
from aiohttp.typedefs import LooseHeaders
|
|
from aiohttp.web import AppKey, Request
|
|
from aiohttp.web_exceptions import (
|
|
HTTPBadRequest,
|
|
HTTPInternalServerError,
|
|
HTTPUnauthorized,
|
|
)
|
|
from aiohttp.web_urldispatcher import AbstractResource, AbstractRoute
|
|
import voluptuous as vol
|
|
|
|
from homeassistant import exceptions
|
|
from homeassistant.const import CONTENT_TYPE_JSON
|
|
from homeassistant.core import Context, HomeAssistant, is_callback
|
|
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS, format_unserializable_data
|
|
|
|
from .json import find_paths_unserializable_data, json_bytes, json_dumps
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
type AllowCorsType = Callable[[AbstractRoute | AbstractResource], None]
|
|
KEY_AUTHENTICATED: Final = "ha_authenticated"
|
|
KEY_ALLOW_ALL_CORS = AppKey[AllowCorsType]("allow_all_cors")
|
|
KEY_ALLOW_CONFIGURED_CORS = AppKey[AllowCorsType]("allow_configured_cors")
|
|
KEY_HASS: AppKey[HomeAssistant] = AppKey("hass")
|
|
|
|
current_request: ContextVar[Request | None] = ContextVar(
|
|
"current_request", default=None
|
|
)
|
|
|
|
|
|
def request_handler_factory(
|
|
hass: HomeAssistant, view: HomeAssistantView, handler: Callable
|
|
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
|
|
"""Wrap the handler classes."""
|
|
is_coroutinefunction = asyncio.iscoroutinefunction(handler)
|
|
assert is_coroutinefunction or is_callback(handler), (
|
|
"Handler should be a coroutine or a callback."
|
|
)
|
|
|
|
async def handle(request: web.Request) -> web.StreamResponse:
|
|
"""Handle incoming request."""
|
|
if hass.is_stopping:
|
|
return web.Response(status=HTTPStatus.SERVICE_UNAVAILABLE)
|
|
|
|
authenticated = request.get(KEY_AUTHENTICATED, False)
|
|
|
|
if view.requires_auth and not authenticated:
|
|
raise HTTPUnauthorized
|
|
|
|
if _LOGGER.isEnabledFor(logging.DEBUG):
|
|
_LOGGER.debug(
|
|
"Serving %s to %s (auth: %s)",
|
|
request.path,
|
|
request.remote,
|
|
authenticated,
|
|
)
|
|
|
|
try:
|
|
if is_coroutinefunction:
|
|
result = await handler(request, **request.match_info)
|
|
else:
|
|
result = handler(request, **request.match_info)
|
|
except vol.Invalid as err:
|
|
raise HTTPBadRequest from err
|
|
except exceptions.ServiceNotFound as err:
|
|
raise HTTPInternalServerError from err
|
|
except exceptions.Unauthorized as err:
|
|
raise HTTPUnauthorized from err
|
|
|
|
if isinstance(result, web.StreamResponse):
|
|
# The method handler returned a ready-made Response, how nice of it
|
|
return result
|
|
|
|
status_code = HTTPStatus.OK
|
|
if isinstance(result, tuple):
|
|
result, status_code = result
|
|
|
|
if isinstance(result, bytes):
|
|
return web.Response(body=result, status=status_code)
|
|
|
|
if isinstance(result, str):
|
|
return web.Response(text=result, status=status_code)
|
|
|
|
if result is None:
|
|
return web.Response(body=b"", status=status_code)
|
|
|
|
raise TypeError(
|
|
f"Result should be None, string, bytes or StreamResponse. Got: {result}"
|
|
)
|
|
|
|
return handle
|
|
|
|
|
|
class HomeAssistantView:
    """Common base class for HTTP views.

    Subclasses set ``url`` (and optionally ``extra_urls``) and define methods
    named after HTTP verbs (``get``, ``post``, ...); ``register`` wires every
    such method to the router.
    """

    # Main URL of the view; register() asserts that it has been set.
    url: str | None = None
    # Further URLs that get the same handlers as ``url``.
    extra_urls: list[str] = []
    # Views inheriting from this class can override this
    requires_auth = True
    # Selects which CORS app key register() looks up.
    cors_allowed = False

    @staticmethod
    def context(request: web.Request) -> Context:
        """Build a Context for a request.

        The context carries the id of the request's ``hass_user`` entry when
        there is one; otherwise an empty Context is returned.
        """
        user = request.get("hass_user")
        if user is not None:
            return Context(user_id=user.id)

        return Context()

    @staticmethod
    def json(
        result: Any,
        status_code: HTTPStatus | int = HTTPStatus.OK,
        headers: LooseHeaders | None = None,
    ) -> web.Response:
        """Serialize ``result`` to JSON and return it as a response.

        Compression is enabled on the returned response. If serialization
        fails, the offending paths are logged and HTTPInternalServerError is
        raised.
        """
        try:
            payload = json_bytes(result)
        except JSON_ENCODE_EXCEPTIONS as err:
            bad_paths = find_paths_unserializable_data(result, dump=json_dumps)
            _LOGGER.error(
                "Unable to serialize to JSON. Bad data found at %s",
                format_unserializable_data(bad_paths),
            )
            raise HTTPInternalServerError from err

        json_response = web.Response(
            body=payload,
            content_type=CONTENT_TYPE_JSON,
            status=int(status_code),
            headers=headers,
            zlib_executor_size=32768,
        )
        json_response.enable_compression()
        return json_response

    def json_message(
        self,
        message: str,
        status_code: HTTPStatus | int = HTTPStatus.OK,
        message_code: str | None = None,
        headers: LooseHeaders | None = None,
    ) -> web.Response:
        """Return a JSON response of the form ``{"message": ...}``.

        A ``"code"`` entry is added when ``message_code`` is given.
        """
        body = (
            {"message": message}
            if message_code is None
            else {"message": message, "code": message_code}
        )
        return self.json(body, status_code, headers=headers)

    def register(
        self, hass: HomeAssistant, app: web.Application, router: web.UrlDispatcher
    ) -> None:
        """Add a route for every HTTP method the view implements.

        Each implemented method is wrapped with request_handler_factory and
        registered for ``url`` and every entry of ``extra_urls``. The created
        routes are then passed to the CORS callable stored on ``app``, if any.
        """
        assert self.url is not None, "No url set for view"
        targets = [self.url, *self.extra_urls]
        registered: list[AbstractRoute] = []

        for verb in ("get", "post", "delete", "put", "patch", "head", "options"):
            method_impl = getattr(self, verb, None)
            if method_impl:
                wrapped = request_handler_factory(hass, self, method_impl)
                for target in targets:
                    registered.append(router.add_route(verb, target, wrapped))

        # Use `get` because the CORS middleware is not loaded in emulated_hue
        cors_key = (
            KEY_ALLOW_ALL_CORS if self.cors_allowed else KEY_ALLOW_CONFIGURED_CORS
        )
        enable_cors = app.get(cors_key)
        if not enable_cors:
            return

        for route in registered:
            enable_cors(route)
|