Add missing type hints in http component (#50411)

pull/50437/head
Ruslan Sayfutdinov 2021-05-10 22:30:47 +01:00 committed by GitHub
parent 85f758380a
commit ce15f28642
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 245 additions and 160 deletions

View File

@ -6,16 +6,18 @@ from ipaddress import ip_network
import logging
import os
import ssl
from typing import Any, Optional, cast
from typing import Any, Final, Optional, TypedDict, cast
from aiohttp import web
from aiohttp.web_exceptions import HTTPMovedPermanently
from aiohttp.typedefs import StrOrURL
from aiohttp.web_exceptions import HTTPMovedPermanently, HTTPRedirection
import voluptuous as vol
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT
from homeassistant.core import Event, HomeAssistant
from homeassistant.helpers import storage
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from homeassistant.setup import async_start_setup, async_when_setup_or_start
import homeassistant.util as hass_util
@ -29,44 +31,42 @@ from .forwarded import async_setup_forwarded
from .request_context import setup_request_context
from .security_filter import setup_security_filter
from .static import CACHE_HEADERS, CachingStaticResource
from .view import HomeAssistantView # noqa: F401
from .view import HomeAssistantView
from .web_runner import HomeAssistantTCPSite
# mypy: allow-untyped-defs, no-check-untyped-defs
DOMAIN: Final = "http"
DOMAIN = "http"
CONF_SERVER_HOST: Final = "server_host"
CONF_SERVER_PORT: Final = "server_port"
CONF_BASE_URL: Final = "base_url"
CONF_SSL_CERTIFICATE: Final = "ssl_certificate"
CONF_SSL_PEER_CERTIFICATE: Final = "ssl_peer_certificate"
CONF_SSL_KEY: Final = "ssl_key"
CONF_CORS_ORIGINS: Final = "cors_allowed_origins"
CONF_USE_X_FORWARDED_FOR: Final = "use_x_forwarded_for"
CONF_TRUSTED_PROXIES: Final = "trusted_proxies"
CONF_LOGIN_ATTEMPTS_THRESHOLD: Final = "login_attempts_threshold"
CONF_IP_BAN_ENABLED: Final = "ip_ban_enabled"
CONF_SSL_PROFILE: Final = "ssl_profile"
CONF_SERVER_HOST = "server_host"
CONF_SERVER_PORT = "server_port"
CONF_BASE_URL = "base_url"
CONF_SSL_CERTIFICATE = "ssl_certificate"
CONF_SSL_PEER_CERTIFICATE = "ssl_peer_certificate"
CONF_SSL_KEY = "ssl_key"
CONF_CORS_ORIGINS = "cors_allowed_origins"
CONF_USE_X_FORWARDED_FOR = "use_x_forwarded_for"
CONF_TRUSTED_PROXIES = "trusted_proxies"
CONF_LOGIN_ATTEMPTS_THRESHOLD = "login_attempts_threshold"
CONF_IP_BAN_ENABLED = "ip_ban_enabled"
CONF_SSL_PROFILE = "ssl_profile"
SSL_MODERN: Final = "modern"
SSL_INTERMEDIATE: Final = "intermediate"
SSL_MODERN = "modern"
SSL_INTERMEDIATE = "intermediate"
_LOGGER: Final = logging.getLogger(__name__)
_LOGGER = logging.getLogger(__name__)
DEFAULT_DEVELOPMENT = "0"
DEFAULT_DEVELOPMENT: Final = "0"
# Cast to be able to load custom cards.
# My to be able to check url and version info.
DEFAULT_CORS = ["https://cast.home-assistant.io"]
NO_LOGIN_ATTEMPT_THRESHOLD = -1
DEFAULT_CORS: Final[list[str]] = ["https://cast.home-assistant.io"]
NO_LOGIN_ATTEMPT_THRESHOLD: Final = -1
MAX_CLIENT_SIZE: int = 1024 ** 2 * 16
MAX_CLIENT_SIZE: Final = 1024 ** 2 * 16
STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1
SAVE_DELAY = 180
STORAGE_KEY: Final = DOMAIN
STORAGE_VERSION: Final = 1
SAVE_DELAY: Final = 180
HTTP_SCHEMA = vol.All(
HTTP_SCHEMA: Final = vol.All(
cv.deprecated(CONF_BASE_URL),
vol.Schema(
{
@ -96,7 +96,24 @@ HTTP_SCHEMA = vol.All(
),
)
CONFIG_SCHEMA = vol.Schema({DOMAIN: HTTP_SCHEMA}, extra=vol.ALLOW_EXTRA)
CONFIG_SCHEMA: Final = vol.Schema({DOMAIN: HTTP_SCHEMA}, extra=vol.ALLOW_EXTRA)
class ConfData(TypedDict, total=False):
"""Typed dict for config data."""
server_host: list[str]
server_port: int
base_url: str
ssl_certificate: str
ssl_peer_certificate: str
ssl_key: str
cors_allowed_origins: list[str]
use_x_forwarded_for: bool
trusted_proxies: list[str]
login_attempts_threshold: int
ip_ban_enabled: bool
ssl_profile: str
@bind_hass
@ -113,8 +130,8 @@ class ApiConfig:
self,
local_ip: str,
host: str,
port: int | None = SERVER_PORT,
use_ssl: bool = False,
port: int,
use_ssl: bool,
) -> None:
"""Initialize a new API config object."""
self.local_ip = local_ip
@ -123,12 +140,12 @@ class ApiConfig:
self.use_ssl = use_ssl
async def async_setup(hass, config):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the HTTP API and debug interface."""
conf = config.get(DOMAIN)
conf: ConfData | None = config.get(DOMAIN)
if conf is None:
conf = HTTP_SCHEMA({})
conf = cast(ConfData, HTTP_SCHEMA({}))
server_host = conf.get(CONF_SERVER_HOST)
server_port = conf[CONF_SERVER_PORT]
@ -137,7 +154,7 @@ async def async_setup(hass, config):
ssl_key = conf.get(CONF_SSL_KEY)
cors_origins = conf[CONF_CORS_ORIGINS]
use_x_forwarded_for = conf.get(CONF_USE_X_FORWARDED_FOR, False)
trusted_proxies = conf.get(CONF_TRUSTED_PROXIES, [])
trusted_proxies = conf.get(CONF_TRUSTED_PROXIES) or []
is_ban_enabled = conf[CONF_IP_BAN_ENABLED]
login_threshold = conf[CONF_LOGIN_ATTEMPTS_THRESHOLD]
ssl_profile = conf[CONF_SSL_PROFILE]
@ -165,6 +182,8 @@ async def async_setup(hass, config):
"""Start the server."""
with async_start_setup(hass, ["http"]):
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
# We already checked it's not None.
assert conf is not None
await start_http_server_and_save_config(hass, dict(conf), server)
async_when_setup_or_start(hass, "frontend", start_server)
@ -190,19 +209,19 @@ class HomeAssistantHTTP:
def __init__(
self,
hass,
ssl_certificate,
ssl_peer_certificate,
ssl_key,
server_host,
server_port,
cors_origins,
use_x_forwarded_for,
trusted_proxies,
login_threshold,
is_ban_enabled,
ssl_profile,
):
hass: HomeAssistant,
ssl_certificate: str | None,
ssl_peer_certificate: str | None,
ssl_key: str | None,
server_host: list[str] | None,
server_port: int,
cors_origins: list[str],
use_x_forwarded_for: bool,
trusted_proxies: list[str],
login_threshold: int,
is_ban_enabled: bool,
ssl_profile: str,
) -> None:
"""Initialize the HTTP Home Assistant server."""
app = self.app = web.Application(
middlewares=[], client_max_size=MAX_CLIENT_SIZE
@ -237,10 +256,10 @@ class HomeAssistantHTTP:
self.is_ban_enabled = is_ban_enabled
self.ssl_profile = ssl_profile
self._handler = None
self.runner = None
self.site = None
self.runner: web.AppRunner | None = None
self.site: HomeAssistantTCPSite | None = None
def register_view(self, view):
def register_view(self, view: HomeAssistantView) -> None:
"""Register a view with the WSGI server.
The view argument must be a class that inherits from HomeAssistantView.
@ -261,7 +280,13 @@ class HomeAssistantHTTP:
view.register(self.app, self.app.router)
def register_redirect(self, url, redirect_to, *, redirect_exc=HTTPMovedPermanently):
def register_redirect(
self,
url: str,
redirect_to: StrOrURL,
*,
redirect_exc: type[HTTPRedirection] = HTTPMovedPermanently,
) -> None:
"""Register a redirect with the server.
If given this must be either a string or callable. In case of a
@ -271,38 +296,39 @@ class HomeAssistantHTTP:
rule syntax.
"""
async def redirect(request):
async def redirect(request: web.Request) -> web.StreamResponse:
"""Redirect to location."""
raise redirect_exc(redirect_to)
# Should be instance of aiohttp.web_exceptions._HTTPMove.
raise redirect_exc(redirect_to) # type: ignore[arg-type,misc]
self.app.router.add_route("GET", url, redirect)
def register_static_path(self, url_path, path, cache_headers=True):
def register_static_path(
self, url_path: str, path: str, cache_headers: bool = True
) -> web.FileResponse | None:
"""Register a folder or file to serve as a static path."""
if os.path.isdir(path):
if cache_headers:
resource = CachingStaticResource
resource: type[
CachingStaticResource | web.StaticResource
] = CachingStaticResource
else:
resource = web.StaticResource
self.app.router.register_resource(resource(url_path, path))
return
return None
if cache_headers:
async def serve_file(request):
"""Serve file from disk."""
async def serve_file(request: web.Request) -> web.FileResponse:
"""Serve file from disk."""
if cache_headers:
return web.FileResponse(path, headers=CACHE_HEADERS)
else:
async def serve_file(request):
"""Serve file from disk."""
return web.FileResponse(path)
return web.FileResponse(path)
self.app.router.add_route("GET", url_path, serve_file)
return None
async def start(self):
async def start(self) -> None:
"""Start the aiohttp server."""
context: ssl.SSLContext | None
if self.ssl_certificate:
try:
if self.ssl_profile == SSL_INTERMEDIATE:
@ -334,7 +360,7 @@ class HomeAssistantHTTP:
# This will now raise a RunTimeError.
# To work around this we now prevent the router from getting frozen
# pylint: disable=protected-access
self.app._router.freeze = lambda: None
self.app._router.freeze = lambda: None # type: ignore[assignment]
self.runner = web.AppRunner(self.app)
await self.runner.setup()
@ -351,17 +377,19 @@ class HomeAssistantHTTP:
_LOGGER.info("Now listening on port %d", self.server_port)
async def stop(self):
async def stop(self) -> None:
"""Stop the aiohttp server."""
await self.site.stop()
await self.runner.cleanup()
if self.site is not None:
await self.site.stop()
if self.runner is not None:
await self.runner.cleanup()
async def start_http_server_and_save_config(
hass: HomeAssistant, conf: dict, server: HomeAssistantHTTP
) -> None:
"""Startup the http server and save the config."""
await server.start() # type: ignore
await server.start()
# If we are set up successful, we store the HTTP settings for safe mode.
store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY)

View File

@ -1,28 +1,33 @@
"""Authentication for HTTP component."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from datetime import timedelta
import logging
import secrets
from typing import Final
from urllib.parse import unquote
from aiohttp import hdrs
from aiohttp.web import middleware
from aiohttp.web import Application, Request, StreamResponse, middleware
import jwt
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import dt as dt_util
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
DATA_API_PASSWORD = "api_password"
DATA_SIGN_SECRET = "http.auth.sign_secret"
SIGN_QUERY_PARAM = "authSig"
DATA_API_PASSWORD: Final = "api_password"
DATA_SIGN_SECRET: Final = "http.auth.sign_secret"
SIGN_QUERY_PARAM: Final = "authSig"
@callback
def async_sign_path(hass, refresh_token_id, path, expiration):
def async_sign_path(
hass: HomeAssistant, refresh_token_id: str, path: str, expiration: timedelta
) -> str:
"""Sign a path for temporary access without auth header."""
secret = hass.data.get(DATA_SIGN_SECRET)
@ -44,17 +49,19 @@ def async_sign_path(hass, refresh_token_id, path, expiration):
@callback
def setup_auth(hass, app):
def setup_auth(hass: HomeAssistant, app: Application) -> None:
"""Create auth middleware for the app."""
async def async_validate_auth_header(request):
async def async_validate_auth_header(request: Request) -> bool:
"""
Test authorization header against access token.
Basic auth_type is legacy code, should be removed with api_password.
"""
try:
auth_type, auth_val = request.headers.get(hdrs.AUTHORIZATION).split(" ", 1)
auth_type, auth_val = request.headers.get(hdrs.AUTHORIZATION, "").split(
" ", 1
)
except ValueError:
# If no space in authorization header
return False
@ -71,7 +78,7 @@ def setup_auth(hass, app):
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
return True
async def async_validate_signed_request(request):
async def async_validate_signed_request(request: Request) -> bool:
"""Validate a signed request."""
secret = hass.data.get(DATA_SIGN_SECRET)
@ -103,7 +110,9 @@ def setup_auth(hass, app):
return True
@middleware
async def auth_middleware(request, handler):
async def auth_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""Authenticate as middleware."""
authenticated = False

View File

@ -2,13 +2,15 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Awaitable, Callable
from contextlib import suppress
from datetime import datetime
from ipaddress import ip_address
import logging
from socket import gethostbyaddr, herror
from typing import Any, Final
from aiohttp.web import middleware
from aiohttp.web import Application, Request, StreamResponse, middleware
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
import voluptuous as vol
@ -19,33 +21,33 @@ from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from homeassistant.util import dt as dt_util, yaml
# mypy: allow-untyped-defs, no-check-untyped-defs
from .view import HomeAssistantView
_LOGGER = logging.getLogger(__name__)
_LOGGER: Final = logging.getLogger(__name__)
KEY_BANNED_IPS = "ha_banned_ips"
KEY_FAILED_LOGIN_ATTEMPTS = "ha_failed_login_attempts"
KEY_LOGIN_THRESHOLD = "ha_login_threshold"
KEY_BANNED_IPS: Final = "ha_banned_ips"
KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts"
KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold"
NOTIFICATION_ID_BAN = "ip-ban"
NOTIFICATION_ID_LOGIN = "http-login"
NOTIFICATION_ID_BAN: Final = "ip-ban"
NOTIFICATION_ID_LOGIN: Final = "http-login"
IP_BANS_FILE = "ip_bans.yaml"
ATTR_BANNED_AT = "banned_at"
IP_BANS_FILE: Final = "ip_bans.yaml"
ATTR_BANNED_AT: Final = "banned_at"
SCHEMA_IP_BAN_ENTRY = vol.Schema(
SCHEMA_IP_BAN_ENTRY: Final = vol.Schema(
{vol.Optional("banned_at"): vol.Any(None, cv.datetime)}
)
@callback
def setup_bans(hass, app, login_threshold):
def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> None:
"""Create IP Ban middleware for the app."""
app.middlewares.append(ban_middleware)
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
app[KEY_LOGIN_THRESHOLD] = login_threshold
async def ban_startup(app):
async def ban_startup(app: Application) -> None:
"""Initialize bans when app starts up."""
app[KEY_BANNED_IPS] = await async_load_ip_bans_config(
hass, hass.config.path(IP_BANS_FILE)
@ -55,7 +57,9 @@ def setup_bans(hass, app, login_threshold):
@middleware
async def ban_middleware(request, handler):
async def ban_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""IP Ban middleware."""
if KEY_BANNED_IPS not in request.app:
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
@ -77,10 +81,14 @@ async def ban_middleware(request, handler):
raise
def log_invalid_auth(func):
def log_invalid_auth(
func: Callable[..., Awaitable[StreamResponse]]
) -> Callable[..., Awaitable[StreamResponse]]:
"""Decorate function to handle invalid auth or failed login attempts."""
async def handle_req(view, request, *args, **kwargs):
async def handle_req(
view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any
) -> StreamResponse:
"""Try to log failed login attempts if response status >= 400."""
resp = await func(view, request, *args, **kwargs)
if resp.status >= HTTP_BAD_REQUEST:
@ -90,7 +98,7 @@ def log_invalid_auth(func):
return handle_req
async def process_wrong_login(request):
async def process_wrong_login(request: Request) -> None:
"""Process a wrong login attempt.
Increase failed login attempts counter for remote IP address.
@ -152,7 +160,7 @@ async def process_wrong_login(request):
)
async def process_success_login(request):
async def process_success_login(request: Request) -> None:
"""Process a success login attempt.
Reset failed login attempts counter for remote IP address.

View File

@ -1,5 +1,7 @@
"""HTTP specific constants."""
KEY_AUTHENTICATED = "ha_authenticated"
KEY_HASS = "hass"
KEY_HASS_USER = "hass_user"
KEY_HASS_REFRESH_TOKEN_ID = "hass_refresh_token_id"
from typing import Final
KEY_AUTHENTICATED: Final = "ha_authenticated"
KEY_HASS: Final = "hass"
KEY_HASS_USER: Final = "hass_user"
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"

View File

@ -1,24 +1,33 @@
"""Provide CORS support for the HTTP component."""
from __future__ import annotations
from typing import Final
from aiohttp.hdrs import ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN
from aiohttp.web_urldispatcher import Resource, ResourceRoute, StaticResource
from aiohttp.web import Application
from aiohttp.web_urldispatcher import (
AbstractResource,
AbstractRoute,
Resource,
ResourceRoute,
StaticResource,
)
from homeassistant.const import HTTP_HEADER_X_REQUESTED_WITH
from homeassistant.core import callback
# mypy: allow-untyped-defs, no-check-untyped-defs
ALLOWED_CORS_HEADERS = [
ALLOWED_CORS_HEADERS: Final[list[str]] = [
ORIGIN,
ACCEPT,
HTTP_HEADER_X_REQUESTED_WITH,
CONTENT_TYPE,
AUTHORIZATION,
]
VALID_CORS_TYPES = (Resource, ResourceRoute, StaticResource)
VALID_CORS_TYPES: Final = (Resource, ResourceRoute, StaticResource)
@callback
def setup_cors(app, origins):
def setup_cors(app: Application, origins: list[str]) -> None:
"""Set up CORS."""
# This import should remain here. That way the HTTP integration can always
# be imported by other integrations without it's requirements being installed.
@ -37,9 +46,12 @@ def setup_cors(app, origins):
cors_added = set()
def _allow_cors(route, config=None):
def _allow_cors(
route: AbstractRoute | AbstractResource,
config: dict[str, aiohttp_cors.ResourceOptions] | None = None,
) -> None:
"""Allow CORS on a route."""
if hasattr(route, "resource"):
if isinstance(route, AbstractRoute):
path = route.resource
else:
path = route
@ -47,16 +59,16 @@ def setup_cors(app, origins):
if not isinstance(path, VALID_CORS_TYPES):
return
path = path.canonical
path_str = path.canonical
if path.startswith("/api/hassio_ingress/"):
if path_str.startswith("/api/hassio_ingress/"):
return
if path in cors_added:
if path_str in cors_added:
return
cors.add(route, config)
cors_added.add(path)
cors_added.add(path_str)
app["allow_cors"] = lambda route: _allow_cors(
route,
@ -70,7 +82,7 @@ def setup_cors(app, origins):
if not origins:
return
async def cors_startup(app):
async def cors_startup(app: Application) -> None:
"""Initialize CORS when app starts up."""
for resource in list(app.router.resources()):
_allow_cors(resource)

View File

@ -1,19 +1,20 @@
"""Middleware to handle forwarded data by a reverse proxy."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from ipaddress import ip_address
import logging
from aiohttp.hdrs import X_FORWARDED_FOR, X_FORWARDED_HOST, X_FORWARDED_PROTO
from aiohttp.web import HTTPBadRequest, middleware
from aiohttp.web import Application, HTTPBadRequest, Request, StreamResponse, middleware
from homeassistant.core import callback
_LOGGER = logging.getLogger(__name__)
# mypy: allow-untyped-defs
@callback
def async_setup_forwarded(app, trusted_proxies):
def async_setup_forwarded(app: Application, trusted_proxies: list[str]) -> None:
"""Create forwarded middleware for the app.
Process IP addresses, proto and host information in the forwarded for headers.
@ -60,17 +61,20 @@ def async_setup_forwarded(app, trusted_proxies):
"""
@middleware
async def forwarded_middleware(request, handler):
async def forwarded_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""Process forwarded data by a reverse proxy."""
overrides = {}
overrides: dict[str, str] = {}
# Handle X-Forwarded-For
forwarded_for_headers = request.headers.getall(X_FORWARDED_FOR, [])
forwarded_for_headers: list[str] = request.headers.getall(X_FORWARDED_FOR, [])
if not forwarded_for_headers:
# No forwarding headers, continue as normal
return await handler(request)
# Ensure the IP of the connected peer is trusted
assert request.transport is not None
connected_ip = ip_address(request.transport.get_extra_info("peername")[0])
if not any(connected_ip in trusted_proxy for trusted_proxy in trusted_proxies):
_LOGGER.warning(
@ -111,7 +115,9 @@ def async_setup_forwarded(app, trusted_proxies):
overrides["remote"] = str(forwarded_for[-1])
# Handle X-Forwarded-Proto
forwarded_proto_headers = request.headers.getall(X_FORWARDED_PROTO, [])
forwarded_proto_headers: list[str] = request.headers.getall(
X_FORWARDED_PROTO, []
)
if forwarded_proto_headers:
if len(forwarded_proto_headers) > 1:
_LOGGER.error(
@ -151,7 +157,7 @@ def async_setup_forwarded(app, trusted_proxies):
overrides["scheme"] = forwarded_proto[forwarded_for_index]
# Handle X-Forwarded-Host
forwarded_host_headers = request.headers.getall(X_FORWARDED_HOST, [])
forwarded_host_headers: list[str] = request.headers.getall(X_FORWARDED_HOST, [])
if forwarded_host_headers:
# Multiple X-Forwarded-Host headers
if len(forwarded_host_headers) > 1:
@ -168,7 +174,7 @@ def async_setup_forwarded(app, trusted_proxies):
overrides["host"] = forwarded_host
# Done, create a new request based on gathered data.
request = request.clone(**overrides)
request = request.clone(**overrides) # type: ignore[arg-type]
return await handler(request)
app.middlewares.append(forwarded_middleware)

View File

@ -1,18 +1,24 @@
"""Middleware to set the request context."""
from __future__ import annotations
from aiohttp.web import middleware
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from aiohttp.web import Application, Request, StreamResponse, middleware
from homeassistant.core import callback
# mypy: allow-untyped-defs
@callback
def setup_request_context(app, context):
def setup_request_context(
app: Application, context: ContextVar[Request | None]
) -> None:
"""Create request context middleware for the app."""
@middleware
async def request_context_middleware(request, handler):
async def request_context_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""Request context middleware."""
context.set(request)
return await handler(request)

View File

@ -1,17 +1,19 @@
"""Middleware to add some basic security filtering to requests."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
import logging
import re
from typing import Final
from aiohttp.web import HTTPBadRequest, middleware
from aiohttp.web import Application, HTTPBadRequest, Request, StreamResponse, middleware
from homeassistant.core import callback
_LOGGER = logging.getLogger(__name__)
# mypy: allow-untyped-defs
# fmt: off
FILTERS = re.compile(
FILTERS: Final = re.compile(
r"(?:"
# Common exploits
@ -34,12 +36,14 @@ FILTERS = re.compile(
@callback
def setup_security_filter(app):
def setup_security_filter(app: Application) -> None:
"""Create security filter middleware for the app."""
@middleware
async def security_filter_middleware(request, handler):
"""Process request and block commonly known exploit attempts."""
async def security_filter_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""Process request and tblock commonly known exploit attempts."""
if FILTERS.search(request.path):
_LOGGER.warning(
"Filtered a potential harmful request to: %s", request.raw_path

View File

@ -1,21 +1,25 @@
"""Static file handling for HTTP component."""
from __future__ import annotations
from collections.abc import Mapping
from pathlib import Path
from typing import Final
from aiohttp import hdrs
from aiohttp.web import FileResponse
from aiohttp.web import FileResponse, Request, StreamResponse
from aiohttp.web_exceptions import HTTPForbidden, HTTPNotFound
from aiohttp.web_urldispatcher import StaticResource
# mypy: allow-untyped-defs
CACHE_TIME = 31 * 86400 # = 1 month
CACHE_HEADERS = {hdrs.CACHE_CONTROL: f"public, max-age={CACHE_TIME}"}
CACHE_TIME: Final = 31 * 86400 # = 1 month
CACHE_HEADERS: Final[Mapping[str, str]] = {
hdrs.CACHE_CONTROL: f"public, max-age={CACHE_TIME}"
}
class CachingStaticResource(StaticResource):
"""Static Resource handler that will add cache headers."""
async def _handle(self, request):
async def _handle(self, request: Request) -> StreamResponse:
rel_url = request.match_info["filename"]
try:
filename = Path(rel_url)
@ -42,7 +46,6 @@ class CachingStaticResource(StaticResource):
return FileResponse(
filepath,
chunk_size=self._chunk_size,
# type ignore: https://github.com/aio-libs/aiohttp/pull/3976
headers=CACHE_HEADERS, # type: ignore
headers=CACHE_HEADERS,
)
raise HTTPNotFound

View File

@ -2,9 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
import json
import logging
from typing import Any, Callable
from typing import Any
from aiohttp import web
from aiohttp.typedefs import LooseHeaders
@ -13,6 +14,7 @@ from aiohttp.web_exceptions import (
HTTPInternalServerError,
HTTPUnauthorized,
)
from aiohttp.web_urldispatcher import AbstractRoute
import voluptuous as vol
from homeassistant import exceptions
@ -81,7 +83,7 @@ class HomeAssistantView:
"""Register the view with a router."""
assert self.url is not None, "No url set for view"
urls = [self.url] + self.extra_urls
routes = []
routes: list[AbstractRoute] = []
for method in ("get", "post", "delete", "put", "patch", "head", "options"):
handler = getattr(self, method, None)
@ -101,7 +103,9 @@ class HomeAssistantView:
app["allow_cors"](route)
def request_handler_factory(view: HomeAssistantView, handler: Callable) -> Callable:
def request_handler_factory(
view: HomeAssistantView, handler: Callable
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
"""Wrap the handler classes."""
assert asyncio.iscoroutinefunction(handler) or is_callback(
handler

View File

@ -23,7 +23,7 @@ class HomeAssistantTCPSite(web.BaseSite):
__slots__ = ("_host", "_port", "_reuse_address", "_reuse_port", "_hosturl")
def __init__( # noqa: D107
def __init__(
self,
runner: web.BaseRunner,
host: None | str | list[str],
@ -35,6 +35,7 @@ class HomeAssistantTCPSite(web.BaseSite):
reuse_address: bool | None = None,
reuse_port: bool | None = None,
) -> None:
"""Initialize HomeAssistantTCPSite."""
super().__init__(
runner,
shutdown_timeout=shutdown_timeout,
@ -47,12 +48,14 @@ class HomeAssistantTCPSite(web.BaseSite):
self._reuse_port = reuse_port
@property
def name(self) -> str: # noqa: D102
def name(self) -> str:
"""Return server URL."""
scheme = "https" if self._ssl_context else "http"
host = self._host[0] if isinstance(self._host, list) else "0.0.0.0"
return str(URL.build(scheme=scheme, host=host, port=self._port))
async def start(self) -> None: # noqa: D102
async def start(self) -> None:
"""Start server."""
await super().start()
loop = asyncio.get_running_loop()
server = self._runner.server

View File

@ -593,7 +593,7 @@ SERVICE_TOGGLE_COVER_TILT = "toggle_cover_tilt"
SERVICE_SELECT_OPTION = "select_option"
# #### API / REMOTE ####
SERVER_PORT = 8123
SERVER_PORT: Final = 8123
URL_ROOT = "/"
URL_API = "/api/"

View File

@ -334,7 +334,7 @@ def async_register_implementation(
if isinstance(implementation, LocalOAuth2Implementation) and not hass.data.get(
DATA_VIEW_REGISTERED, False
):
hass.http.register_view(OAuth2AuthorizeCallbackView()) # type: ignore
hass.http.register_view(OAuth2AuthorizeCallbackView())
hass.data[DATA_VIEW_REGISTERED] = True
implementations = hass.data.setdefault(DATA_IMPLEMENTATIONS, {})