Add missing type hints in http component (#50411)
parent
85f758380a
commit
ce15f28642
|
@ -6,16 +6,18 @@ from ipaddress import ip_network
|
|||
import logging
|
||||
import os
|
||||
import ssl
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Final, Optional, TypedDict, cast
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.web_exceptions import HTTPMovedPermanently
|
||||
from aiohttp.typedefs import StrOrURL
|
||||
from aiohttp.web_exceptions import HTTPMovedPermanently, HTTPRedirection
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT
|
||||
from homeassistant.core import Event, HomeAssistant
|
||||
from homeassistant.helpers import storage
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.setup import async_start_setup, async_when_setup_or_start
|
||||
import homeassistant.util as hass_util
|
||||
|
@ -29,44 +31,42 @@ from .forwarded import async_setup_forwarded
|
|||
from .request_context import setup_request_context
|
||||
from .security_filter import setup_security_filter
|
||||
from .static import CACHE_HEADERS, CachingStaticResource
|
||||
from .view import HomeAssistantView # noqa: F401
|
||||
from .view import HomeAssistantView
|
||||
from .web_runner import HomeAssistantTCPSite
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
DOMAIN: Final = "http"
|
||||
|
||||
DOMAIN = "http"
|
||||
CONF_SERVER_HOST: Final = "server_host"
|
||||
CONF_SERVER_PORT: Final = "server_port"
|
||||
CONF_BASE_URL: Final = "base_url"
|
||||
CONF_SSL_CERTIFICATE: Final = "ssl_certificate"
|
||||
CONF_SSL_PEER_CERTIFICATE: Final = "ssl_peer_certificate"
|
||||
CONF_SSL_KEY: Final = "ssl_key"
|
||||
CONF_CORS_ORIGINS: Final = "cors_allowed_origins"
|
||||
CONF_USE_X_FORWARDED_FOR: Final = "use_x_forwarded_for"
|
||||
CONF_TRUSTED_PROXIES: Final = "trusted_proxies"
|
||||
CONF_LOGIN_ATTEMPTS_THRESHOLD: Final = "login_attempts_threshold"
|
||||
CONF_IP_BAN_ENABLED: Final = "ip_ban_enabled"
|
||||
CONF_SSL_PROFILE: Final = "ssl_profile"
|
||||
|
||||
CONF_SERVER_HOST = "server_host"
|
||||
CONF_SERVER_PORT = "server_port"
|
||||
CONF_BASE_URL = "base_url"
|
||||
CONF_SSL_CERTIFICATE = "ssl_certificate"
|
||||
CONF_SSL_PEER_CERTIFICATE = "ssl_peer_certificate"
|
||||
CONF_SSL_KEY = "ssl_key"
|
||||
CONF_CORS_ORIGINS = "cors_allowed_origins"
|
||||
CONF_USE_X_FORWARDED_FOR = "use_x_forwarded_for"
|
||||
CONF_TRUSTED_PROXIES = "trusted_proxies"
|
||||
CONF_LOGIN_ATTEMPTS_THRESHOLD = "login_attempts_threshold"
|
||||
CONF_IP_BAN_ENABLED = "ip_ban_enabled"
|
||||
CONF_SSL_PROFILE = "ssl_profile"
|
||||
SSL_MODERN: Final = "modern"
|
||||
SSL_INTERMEDIATE: Final = "intermediate"
|
||||
|
||||
SSL_MODERN = "modern"
|
||||
SSL_INTERMEDIATE = "intermediate"
|
||||
_LOGGER: Final = logging.getLogger(__name__)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_DEVELOPMENT = "0"
|
||||
DEFAULT_DEVELOPMENT: Final = "0"
|
||||
# Cast to be able to load custom cards.
|
||||
# My to be able to check url and version info.
|
||||
DEFAULT_CORS = ["https://cast.home-assistant.io"]
|
||||
NO_LOGIN_ATTEMPT_THRESHOLD = -1
|
||||
DEFAULT_CORS: Final[list[str]] = ["https://cast.home-assistant.io"]
|
||||
NO_LOGIN_ATTEMPT_THRESHOLD: Final = -1
|
||||
|
||||
MAX_CLIENT_SIZE: int = 1024 ** 2 * 16
|
||||
MAX_CLIENT_SIZE: Final = 1024 ** 2 * 16
|
||||
|
||||
STORAGE_KEY = DOMAIN
|
||||
STORAGE_VERSION = 1
|
||||
SAVE_DELAY = 180
|
||||
STORAGE_KEY: Final = DOMAIN
|
||||
STORAGE_VERSION: Final = 1
|
||||
SAVE_DELAY: Final = 180
|
||||
|
||||
HTTP_SCHEMA = vol.All(
|
||||
HTTP_SCHEMA: Final = vol.All(
|
||||
cv.deprecated(CONF_BASE_URL),
|
||||
vol.Schema(
|
||||
{
|
||||
|
@ -96,7 +96,24 @@ HTTP_SCHEMA = vol.All(
|
|||
),
|
||||
)
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema({DOMAIN: HTTP_SCHEMA}, extra=vol.ALLOW_EXTRA)
|
||||
CONFIG_SCHEMA: Final = vol.Schema({DOMAIN: HTTP_SCHEMA}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
class ConfData(TypedDict, total=False):
|
||||
"""Typed dict for config data."""
|
||||
|
||||
server_host: list[str]
|
||||
server_port: int
|
||||
base_url: str
|
||||
ssl_certificate: str
|
||||
ssl_peer_certificate: str
|
||||
ssl_key: str
|
||||
cors_allowed_origins: list[str]
|
||||
use_x_forwarded_for: bool
|
||||
trusted_proxies: list[str]
|
||||
login_attempts_threshold: int
|
||||
ip_ban_enabled: bool
|
||||
ssl_profile: str
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
@ -113,8 +130,8 @@ class ApiConfig:
|
|||
self,
|
||||
local_ip: str,
|
||||
host: str,
|
||||
port: int | None = SERVER_PORT,
|
||||
use_ssl: bool = False,
|
||||
port: int,
|
||||
use_ssl: bool,
|
||||
) -> None:
|
||||
"""Initialize a new API config object."""
|
||||
self.local_ip = local_ip
|
||||
|
@ -123,12 +140,12 @@ class ApiConfig:
|
|||
self.use_ssl = use_ssl
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the HTTP API and debug interface."""
|
||||
conf = config.get(DOMAIN)
|
||||
conf: ConfData | None = config.get(DOMAIN)
|
||||
|
||||
if conf is None:
|
||||
conf = HTTP_SCHEMA({})
|
||||
conf = cast(ConfData, HTTP_SCHEMA({}))
|
||||
|
||||
server_host = conf.get(CONF_SERVER_HOST)
|
||||
server_port = conf[CONF_SERVER_PORT]
|
||||
|
@ -137,7 +154,7 @@ async def async_setup(hass, config):
|
|||
ssl_key = conf.get(CONF_SSL_KEY)
|
||||
cors_origins = conf[CONF_CORS_ORIGINS]
|
||||
use_x_forwarded_for = conf.get(CONF_USE_X_FORWARDED_FOR, False)
|
||||
trusted_proxies = conf.get(CONF_TRUSTED_PROXIES, [])
|
||||
trusted_proxies = conf.get(CONF_TRUSTED_PROXIES) or []
|
||||
is_ban_enabled = conf[CONF_IP_BAN_ENABLED]
|
||||
login_threshold = conf[CONF_LOGIN_ATTEMPTS_THRESHOLD]
|
||||
ssl_profile = conf[CONF_SSL_PROFILE]
|
||||
|
@ -165,6 +182,8 @@ async def async_setup(hass, config):
|
|||
"""Start the server."""
|
||||
with async_start_setup(hass, ["http"]):
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
|
||||
# We already checked it's not None.
|
||||
assert conf is not None
|
||||
await start_http_server_and_save_config(hass, dict(conf), server)
|
||||
|
||||
async_when_setup_or_start(hass, "frontend", start_server)
|
||||
|
@ -190,19 +209,19 @@ class HomeAssistantHTTP:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
hass,
|
||||
ssl_certificate,
|
||||
ssl_peer_certificate,
|
||||
ssl_key,
|
||||
server_host,
|
||||
server_port,
|
||||
cors_origins,
|
||||
use_x_forwarded_for,
|
||||
trusted_proxies,
|
||||
login_threshold,
|
||||
is_ban_enabled,
|
||||
ssl_profile,
|
||||
):
|
||||
hass: HomeAssistant,
|
||||
ssl_certificate: str | None,
|
||||
ssl_peer_certificate: str | None,
|
||||
ssl_key: str | None,
|
||||
server_host: list[str] | None,
|
||||
server_port: int,
|
||||
cors_origins: list[str],
|
||||
use_x_forwarded_for: bool,
|
||||
trusted_proxies: list[str],
|
||||
login_threshold: int,
|
||||
is_ban_enabled: bool,
|
||||
ssl_profile: str,
|
||||
) -> None:
|
||||
"""Initialize the HTTP Home Assistant server."""
|
||||
app = self.app = web.Application(
|
||||
middlewares=[], client_max_size=MAX_CLIENT_SIZE
|
||||
|
@ -237,10 +256,10 @@ class HomeAssistantHTTP:
|
|||
self.is_ban_enabled = is_ban_enabled
|
||||
self.ssl_profile = ssl_profile
|
||||
self._handler = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self.runner: web.AppRunner | None = None
|
||||
self.site: HomeAssistantTCPSite | None = None
|
||||
|
||||
def register_view(self, view):
|
||||
def register_view(self, view: HomeAssistantView) -> None:
|
||||
"""Register a view with the WSGI server.
|
||||
|
||||
The view argument must be a class that inherits from HomeAssistantView.
|
||||
|
@ -261,7 +280,13 @@ class HomeAssistantHTTP:
|
|||
|
||||
view.register(self.app, self.app.router)
|
||||
|
||||
def register_redirect(self, url, redirect_to, *, redirect_exc=HTTPMovedPermanently):
|
||||
def register_redirect(
|
||||
self,
|
||||
url: str,
|
||||
redirect_to: StrOrURL,
|
||||
*,
|
||||
redirect_exc: type[HTTPRedirection] = HTTPMovedPermanently,
|
||||
) -> None:
|
||||
"""Register a redirect with the server.
|
||||
|
||||
If given this must be either a string or callable. In case of a
|
||||
|
@ -271,38 +296,39 @@ class HomeAssistantHTTP:
|
|||
rule syntax.
|
||||
"""
|
||||
|
||||
async def redirect(request):
|
||||
async def redirect(request: web.Request) -> web.StreamResponse:
|
||||
"""Redirect to location."""
|
||||
raise redirect_exc(redirect_to)
|
||||
# Should be instance of aiohttp.web_exceptions._HTTPMove.
|
||||
raise redirect_exc(redirect_to) # type: ignore[arg-type,misc]
|
||||
|
||||
self.app.router.add_route("GET", url, redirect)
|
||||
|
||||
def register_static_path(self, url_path, path, cache_headers=True):
|
||||
def register_static_path(
|
||||
self, url_path: str, path: str, cache_headers: bool = True
|
||||
) -> web.FileResponse | None:
|
||||
"""Register a folder or file to serve as a static path."""
|
||||
if os.path.isdir(path):
|
||||
if cache_headers:
|
||||
resource = CachingStaticResource
|
||||
resource: type[
|
||||
CachingStaticResource | web.StaticResource
|
||||
] = CachingStaticResource
|
||||
else:
|
||||
resource = web.StaticResource
|
||||
self.app.router.register_resource(resource(url_path, path))
|
||||
return
|
||||
return None
|
||||
|
||||
if cache_headers:
|
||||
|
||||
async def serve_file(request):
|
||||
"""Serve file from disk."""
|
||||
async def serve_file(request: web.Request) -> web.FileResponse:
|
||||
"""Serve file from disk."""
|
||||
if cache_headers:
|
||||
return web.FileResponse(path, headers=CACHE_HEADERS)
|
||||
|
||||
else:
|
||||
|
||||
async def serve_file(request):
|
||||
"""Serve file from disk."""
|
||||
return web.FileResponse(path)
|
||||
return web.FileResponse(path)
|
||||
|
||||
self.app.router.add_route("GET", url_path, serve_file)
|
||||
return None
|
||||
|
||||
async def start(self):
|
||||
async def start(self) -> None:
|
||||
"""Start the aiohttp server."""
|
||||
context: ssl.SSLContext | None
|
||||
if self.ssl_certificate:
|
||||
try:
|
||||
if self.ssl_profile == SSL_INTERMEDIATE:
|
||||
|
@ -334,7 +360,7 @@ class HomeAssistantHTTP:
|
|||
# This will now raise a RunTimeError.
|
||||
# To work around this we now prevent the router from getting frozen
|
||||
# pylint: disable=protected-access
|
||||
self.app._router.freeze = lambda: None
|
||||
self.app._router.freeze = lambda: None # type: ignore[assignment]
|
||||
|
||||
self.runner = web.AppRunner(self.app)
|
||||
await self.runner.setup()
|
||||
|
@ -351,17 +377,19 @@ class HomeAssistantHTTP:
|
|||
|
||||
_LOGGER.info("Now listening on port %d", self.server_port)
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self) -> None:
|
||||
"""Stop the aiohttp server."""
|
||||
await self.site.stop()
|
||||
await self.runner.cleanup()
|
||||
if self.site is not None:
|
||||
await self.site.stop()
|
||||
if self.runner is not None:
|
||||
await self.runner.cleanup()
|
||||
|
||||
|
||||
async def start_http_server_and_save_config(
|
||||
hass: HomeAssistant, conf: dict, server: HomeAssistantHTTP
|
||||
) -> None:
|
||||
"""Startup the http server and save the config."""
|
||||
await server.start() # type: ignore
|
||||
await server.start()
|
||||
|
||||
# If we are set up successful, we store the HTTP settings for safe mode.
|
||||
store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
|
|
|
@ -1,28 +1,33 @@
|
|||
"""Authentication for HTTP component."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Final
|
||||
from urllib.parse import unquote
|
||||
|
||||
from aiohttp import hdrs
|
||||
from aiohttp.web import middleware
|
||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||
import jwt
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DATA_API_PASSWORD = "api_password"
|
||||
DATA_SIGN_SECRET = "http.auth.sign_secret"
|
||||
SIGN_QUERY_PARAM = "authSig"
|
||||
DATA_API_PASSWORD: Final = "api_password"
|
||||
DATA_SIGN_SECRET: Final = "http.auth.sign_secret"
|
||||
SIGN_QUERY_PARAM: Final = "authSig"
|
||||
|
||||
|
||||
@callback
|
||||
def async_sign_path(hass, refresh_token_id, path, expiration):
|
||||
def async_sign_path(
|
||||
hass: HomeAssistant, refresh_token_id: str, path: str, expiration: timedelta
|
||||
) -> str:
|
||||
"""Sign a path for temporary access without auth header."""
|
||||
secret = hass.data.get(DATA_SIGN_SECRET)
|
||||
|
||||
|
@ -44,17 +49,19 @@ def async_sign_path(hass, refresh_token_id, path, expiration):
|
|||
|
||||
|
||||
@callback
|
||||
def setup_auth(hass, app):
|
||||
def setup_auth(hass: HomeAssistant, app: Application) -> None:
|
||||
"""Create auth middleware for the app."""
|
||||
|
||||
async def async_validate_auth_header(request):
|
||||
async def async_validate_auth_header(request: Request) -> bool:
|
||||
"""
|
||||
Test authorization header against access token.
|
||||
|
||||
Basic auth_type is legacy code, should be removed with api_password.
|
||||
"""
|
||||
try:
|
||||
auth_type, auth_val = request.headers.get(hdrs.AUTHORIZATION).split(" ", 1)
|
||||
auth_type, auth_val = request.headers.get(hdrs.AUTHORIZATION, "").split(
|
||||
" ", 1
|
||||
)
|
||||
except ValueError:
|
||||
# If no space in authorization header
|
||||
return False
|
||||
|
@ -71,7 +78,7 @@ def setup_auth(hass, app):
|
|||
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
|
||||
return True
|
||||
|
||||
async def async_validate_signed_request(request):
|
||||
async def async_validate_signed_request(request: Request) -> bool:
|
||||
"""Validate a signed request."""
|
||||
secret = hass.data.get(DATA_SIGN_SECRET)
|
||||
|
||||
|
@ -103,7 +110,9 @@ def setup_auth(hass, app):
|
|||
return True
|
||||
|
||||
@middleware
|
||||
async def auth_middleware(request, handler):
|
||||
async def auth_middleware(
|
||||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
) -> StreamResponse:
|
||||
"""Authenticate as middleware."""
|
||||
authenticated = False
|
||||
|
||||
|
|
|
@ -2,13 +2,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
from ipaddress import ip_address
|
||||
import logging
|
||||
from socket import gethostbyaddr, herror
|
||||
from typing import Any, Final
|
||||
|
||||
from aiohttp.web import middleware
|
||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -19,33 +21,33 @@ from homeassistant.exceptions import HomeAssistantError
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.util import dt as dt_util, yaml
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
from .view import HomeAssistantView
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_LOGGER: Final = logging.getLogger(__name__)
|
||||
|
||||
KEY_BANNED_IPS = "ha_banned_ips"
|
||||
KEY_FAILED_LOGIN_ATTEMPTS = "ha_failed_login_attempts"
|
||||
KEY_LOGIN_THRESHOLD = "ha_login_threshold"
|
||||
KEY_BANNED_IPS: Final = "ha_banned_ips"
|
||||
KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts"
|
||||
KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold"
|
||||
|
||||
NOTIFICATION_ID_BAN = "ip-ban"
|
||||
NOTIFICATION_ID_LOGIN = "http-login"
|
||||
NOTIFICATION_ID_BAN: Final = "ip-ban"
|
||||
NOTIFICATION_ID_LOGIN: Final = "http-login"
|
||||
|
||||
IP_BANS_FILE = "ip_bans.yaml"
|
||||
ATTR_BANNED_AT = "banned_at"
|
||||
IP_BANS_FILE: Final = "ip_bans.yaml"
|
||||
ATTR_BANNED_AT: Final = "banned_at"
|
||||
|
||||
SCHEMA_IP_BAN_ENTRY = vol.Schema(
|
||||
SCHEMA_IP_BAN_ENTRY: Final = vol.Schema(
|
||||
{vol.Optional("banned_at"): vol.Any(None, cv.datetime)}
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def setup_bans(hass, app, login_threshold):
|
||||
def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> None:
|
||||
"""Create IP Ban middleware for the app."""
|
||||
app.middlewares.append(ban_middleware)
|
||||
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
||||
app[KEY_LOGIN_THRESHOLD] = login_threshold
|
||||
|
||||
async def ban_startup(app):
|
||||
async def ban_startup(app: Application) -> None:
|
||||
"""Initialize bans when app starts up."""
|
||||
app[KEY_BANNED_IPS] = await async_load_ip_bans_config(
|
||||
hass, hass.config.path(IP_BANS_FILE)
|
||||
|
@ -55,7 +57,9 @@ def setup_bans(hass, app, login_threshold):
|
|||
|
||||
|
||||
@middleware
|
||||
async def ban_middleware(request, handler):
|
||||
async def ban_middleware(
|
||||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
) -> StreamResponse:
|
||||
"""IP Ban middleware."""
|
||||
if KEY_BANNED_IPS not in request.app:
|
||||
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
|
||||
|
@ -77,10 +81,14 @@ async def ban_middleware(request, handler):
|
|||
raise
|
||||
|
||||
|
||||
def log_invalid_auth(func):
|
||||
def log_invalid_auth(
|
||||
func: Callable[..., Awaitable[StreamResponse]]
|
||||
) -> Callable[..., Awaitable[StreamResponse]]:
|
||||
"""Decorate function to handle invalid auth or failed login attempts."""
|
||||
|
||||
async def handle_req(view, request, *args, **kwargs):
|
||||
async def handle_req(
|
||||
view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any
|
||||
) -> StreamResponse:
|
||||
"""Try to log failed login attempts if response status >= 400."""
|
||||
resp = await func(view, request, *args, **kwargs)
|
||||
if resp.status >= HTTP_BAD_REQUEST:
|
||||
|
@ -90,7 +98,7 @@ def log_invalid_auth(func):
|
|||
return handle_req
|
||||
|
||||
|
||||
async def process_wrong_login(request):
|
||||
async def process_wrong_login(request: Request) -> None:
|
||||
"""Process a wrong login attempt.
|
||||
|
||||
Increase failed login attempts counter for remote IP address.
|
||||
|
@ -152,7 +160,7 @@ async def process_wrong_login(request):
|
|||
)
|
||||
|
||||
|
||||
async def process_success_login(request):
|
||||
async def process_success_login(request: Request) -> None:
|
||||
"""Process a success login attempt.
|
||||
|
||||
Reset failed login attempts counter for remote IP address.
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
"""HTTP specific constants."""
|
||||
KEY_AUTHENTICATED = "ha_authenticated"
|
||||
KEY_HASS = "hass"
|
||||
KEY_HASS_USER = "hass_user"
|
||||
KEY_HASS_REFRESH_TOKEN_ID = "hass_refresh_token_id"
|
||||
from typing import Final
|
||||
|
||||
KEY_AUTHENTICATED: Final = "ha_authenticated"
|
||||
KEY_HASS: Final = "hass"
|
||||
KEY_HASS_USER: Final = "hass_user"
|
||||
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"
|
||||
|
|
|
@ -1,24 +1,33 @@
|
|||
"""Provide CORS support for the HTTP component."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Final
|
||||
|
||||
from aiohttp.hdrs import ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN
|
||||
from aiohttp.web_urldispatcher import Resource, ResourceRoute, StaticResource
|
||||
from aiohttp.web import Application
|
||||
from aiohttp.web_urldispatcher import (
|
||||
AbstractResource,
|
||||
AbstractRoute,
|
||||
Resource,
|
||||
ResourceRoute,
|
||||
StaticResource,
|
||||
)
|
||||
|
||||
from homeassistant.const import HTTP_HEADER_X_REQUESTED_WITH
|
||||
from homeassistant.core import callback
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
ALLOWED_CORS_HEADERS = [
|
||||
ALLOWED_CORS_HEADERS: Final[list[str]] = [
|
||||
ORIGIN,
|
||||
ACCEPT,
|
||||
HTTP_HEADER_X_REQUESTED_WITH,
|
||||
CONTENT_TYPE,
|
||||
AUTHORIZATION,
|
||||
]
|
||||
VALID_CORS_TYPES = (Resource, ResourceRoute, StaticResource)
|
||||
VALID_CORS_TYPES: Final = (Resource, ResourceRoute, StaticResource)
|
||||
|
||||
|
||||
@callback
|
||||
def setup_cors(app, origins):
|
||||
def setup_cors(app: Application, origins: list[str]) -> None:
|
||||
"""Set up CORS."""
|
||||
# This import should remain here. That way the HTTP integration can always
|
||||
# be imported by other integrations without it's requirements being installed.
|
||||
|
@ -37,9 +46,12 @@ def setup_cors(app, origins):
|
|||
|
||||
cors_added = set()
|
||||
|
||||
def _allow_cors(route, config=None):
|
||||
def _allow_cors(
|
||||
route: AbstractRoute | AbstractResource,
|
||||
config: dict[str, aiohttp_cors.ResourceOptions] | None = None,
|
||||
) -> None:
|
||||
"""Allow CORS on a route."""
|
||||
if hasattr(route, "resource"):
|
||||
if isinstance(route, AbstractRoute):
|
||||
path = route.resource
|
||||
else:
|
||||
path = route
|
||||
|
@ -47,16 +59,16 @@ def setup_cors(app, origins):
|
|||
if not isinstance(path, VALID_CORS_TYPES):
|
||||
return
|
||||
|
||||
path = path.canonical
|
||||
path_str = path.canonical
|
||||
|
||||
if path.startswith("/api/hassio_ingress/"):
|
||||
if path_str.startswith("/api/hassio_ingress/"):
|
||||
return
|
||||
|
||||
if path in cors_added:
|
||||
if path_str in cors_added:
|
||||
return
|
||||
|
||||
cors.add(route, config)
|
||||
cors_added.add(path)
|
||||
cors_added.add(path_str)
|
||||
|
||||
app["allow_cors"] = lambda route: _allow_cors(
|
||||
route,
|
||||
|
@ -70,7 +82,7 @@ def setup_cors(app, origins):
|
|||
if not origins:
|
||||
return
|
||||
|
||||
async def cors_startup(app):
|
||||
async def cors_startup(app: Application) -> None:
|
||||
"""Initialize CORS when app starts up."""
|
||||
for resource in list(app.router.resources()):
|
||||
_allow_cors(resource)
|
||||
|
|
|
@ -1,19 +1,20 @@
|
|||
"""Middleware to handle forwarded data by a reverse proxy."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from ipaddress import ip_address
|
||||
import logging
|
||||
|
||||
from aiohttp.hdrs import X_FORWARDED_FOR, X_FORWARDED_HOST, X_FORWARDED_PROTO
|
||||
from aiohttp.web import HTTPBadRequest, middleware
|
||||
from aiohttp.web import Application, HTTPBadRequest, Request, StreamResponse, middleware
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
|
||||
@callback
|
||||
def async_setup_forwarded(app, trusted_proxies):
|
||||
def async_setup_forwarded(app: Application, trusted_proxies: list[str]) -> None:
|
||||
"""Create forwarded middleware for the app.
|
||||
|
||||
Process IP addresses, proto and host information in the forwarded for headers.
|
||||
|
@ -60,17 +61,20 @@ def async_setup_forwarded(app, trusted_proxies):
|
|||
"""
|
||||
|
||||
@middleware
|
||||
async def forwarded_middleware(request, handler):
|
||||
async def forwarded_middleware(
|
||||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
) -> StreamResponse:
|
||||
"""Process forwarded data by a reverse proxy."""
|
||||
overrides = {}
|
||||
overrides: dict[str, str] = {}
|
||||
|
||||
# Handle X-Forwarded-For
|
||||
forwarded_for_headers = request.headers.getall(X_FORWARDED_FOR, [])
|
||||
forwarded_for_headers: list[str] = request.headers.getall(X_FORWARDED_FOR, [])
|
||||
if not forwarded_for_headers:
|
||||
# No forwarding headers, continue as normal
|
||||
return await handler(request)
|
||||
|
||||
# Ensure the IP of the connected peer is trusted
|
||||
assert request.transport is not None
|
||||
connected_ip = ip_address(request.transport.get_extra_info("peername")[0])
|
||||
if not any(connected_ip in trusted_proxy for trusted_proxy in trusted_proxies):
|
||||
_LOGGER.warning(
|
||||
|
@ -111,7 +115,9 @@ def async_setup_forwarded(app, trusted_proxies):
|
|||
overrides["remote"] = str(forwarded_for[-1])
|
||||
|
||||
# Handle X-Forwarded-Proto
|
||||
forwarded_proto_headers = request.headers.getall(X_FORWARDED_PROTO, [])
|
||||
forwarded_proto_headers: list[str] = request.headers.getall(
|
||||
X_FORWARDED_PROTO, []
|
||||
)
|
||||
if forwarded_proto_headers:
|
||||
if len(forwarded_proto_headers) > 1:
|
||||
_LOGGER.error(
|
||||
|
@ -151,7 +157,7 @@ def async_setup_forwarded(app, trusted_proxies):
|
|||
overrides["scheme"] = forwarded_proto[forwarded_for_index]
|
||||
|
||||
# Handle X-Forwarded-Host
|
||||
forwarded_host_headers = request.headers.getall(X_FORWARDED_HOST, [])
|
||||
forwarded_host_headers: list[str] = request.headers.getall(X_FORWARDED_HOST, [])
|
||||
if forwarded_host_headers:
|
||||
# Multiple X-Forwarded-Host headers
|
||||
if len(forwarded_host_headers) > 1:
|
||||
|
@ -168,7 +174,7 @@ def async_setup_forwarded(app, trusted_proxies):
|
|||
overrides["host"] = forwarded_host
|
||||
|
||||
# Done, create a new request based on gathered data.
|
||||
request = request.clone(**overrides)
|
||||
request = request.clone(**overrides) # type: ignore[arg-type]
|
||||
return await handler(request)
|
||||
|
||||
app.middlewares.append(forwarded_middleware)
|
||||
|
|
|
@ -1,18 +1,24 @@
|
|||
"""Middleware to set the request context."""
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp.web import middleware
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextvars import ContextVar
|
||||
|
||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
|
||||
@callback
|
||||
def setup_request_context(app, context):
|
||||
def setup_request_context(
|
||||
app: Application, context: ContextVar[Request | None]
|
||||
) -> None:
|
||||
"""Create request context middleware for the app."""
|
||||
|
||||
@middleware
|
||||
async def request_context_middleware(request, handler):
|
||||
async def request_context_middleware(
|
||||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
) -> StreamResponse:
|
||||
"""Request context middleware."""
|
||||
context.set(request)
|
||||
return await handler(request)
|
||||
|
|
|
@ -1,17 +1,19 @@
|
|||
"""Middleware to add some basic security filtering to requests."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import re
|
||||
from typing import Final
|
||||
|
||||
from aiohttp.web import HTTPBadRequest, middleware
|
||||
from aiohttp.web import Application, HTTPBadRequest, Request, StreamResponse, middleware
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
# fmt: off
|
||||
FILTERS = re.compile(
|
||||
FILTERS: Final = re.compile(
|
||||
r"(?:"
|
||||
|
||||
# Common exploits
|
||||
|
@ -34,12 +36,14 @@ FILTERS = re.compile(
|
|||
|
||||
|
||||
@callback
|
||||
def setup_security_filter(app):
|
||||
def setup_security_filter(app: Application) -> None:
|
||||
"""Create security filter middleware for the app."""
|
||||
|
||||
@middleware
|
||||
async def security_filter_middleware(request, handler):
|
||||
"""Process request and block commonly known exploit attempts."""
|
||||
async def security_filter_middleware(
|
||||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
) -> StreamResponse:
|
||||
"""Process request and tblock commonly known exploit attempts."""
|
||||
if FILTERS.search(request.path):
|
||||
_LOGGER.warning(
|
||||
"Filtered a potential harmful request to: %s", request.raw_path
|
||||
|
|
|
@ -1,21 +1,25 @@
|
|||
"""Static file handling for HTTP component."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
from aiohttp import hdrs
|
||||
from aiohttp.web import FileResponse
|
||||
from aiohttp.web import FileResponse, Request, StreamResponse
|
||||
from aiohttp.web_exceptions import HTTPForbidden, HTTPNotFound
|
||||
from aiohttp.web_urldispatcher import StaticResource
|
||||
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
CACHE_TIME = 31 * 86400 # = 1 month
|
||||
CACHE_HEADERS = {hdrs.CACHE_CONTROL: f"public, max-age={CACHE_TIME}"}
|
||||
CACHE_TIME: Final = 31 * 86400 # = 1 month
|
||||
CACHE_HEADERS: Final[Mapping[str, str]] = {
|
||||
hdrs.CACHE_CONTROL: f"public, max-age={CACHE_TIME}"
|
||||
}
|
||||
|
||||
|
||||
class CachingStaticResource(StaticResource):
|
||||
"""Static Resource handler that will add cache headers."""
|
||||
|
||||
async def _handle(self, request):
|
||||
async def _handle(self, request: Request) -> StreamResponse:
|
||||
rel_url = request.match_info["filename"]
|
||||
try:
|
||||
filename = Path(rel_url)
|
||||
|
@ -42,7 +46,6 @@ class CachingStaticResource(StaticResource):
|
|||
return FileResponse(
|
||||
filepath,
|
||||
chunk_size=self._chunk_size,
|
||||
# type ignore: https://github.com/aio-libs/aiohttp/pull/3976
|
||||
headers=CACHE_HEADERS, # type: ignore
|
||||
headers=CACHE_HEADERS,
|
||||
)
|
||||
raise HTTPNotFound
|
||||
|
|
|
@ -2,9 +2,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.typedefs import LooseHeaders
|
||||
|
@ -13,6 +14,7 @@ from aiohttp.web_exceptions import (
|
|||
HTTPInternalServerError,
|
||||
HTTPUnauthorized,
|
||||
)
|
||||
from aiohttp.web_urldispatcher import AbstractRoute
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import exceptions
|
||||
|
@ -81,7 +83,7 @@ class HomeAssistantView:
|
|||
"""Register the view with a router."""
|
||||
assert self.url is not None, "No url set for view"
|
||||
urls = [self.url] + self.extra_urls
|
||||
routes = []
|
||||
routes: list[AbstractRoute] = []
|
||||
|
||||
for method in ("get", "post", "delete", "put", "patch", "head", "options"):
|
||||
handler = getattr(self, method, None)
|
||||
|
@ -101,7 +103,9 @@ class HomeAssistantView:
|
|||
app["allow_cors"](route)
|
||||
|
||||
|
||||
def request_handler_factory(view: HomeAssistantView, handler: Callable) -> Callable:
|
||||
def request_handler_factory(
|
||||
view: HomeAssistantView, handler: Callable
|
||||
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
|
||||
"""Wrap the handler classes."""
|
||||
assert asyncio.iscoroutinefunction(handler) or is_callback(
|
||||
handler
|
||||
|
|
|
@ -23,7 +23,7 @@ class HomeAssistantTCPSite(web.BaseSite):
|
|||
|
||||
__slots__ = ("_host", "_port", "_reuse_address", "_reuse_port", "_hosturl")
|
||||
|
||||
def __init__( # noqa: D107
|
||||
def __init__(
|
||||
self,
|
||||
runner: web.BaseRunner,
|
||||
host: None | str | list[str],
|
||||
|
@ -35,6 +35,7 @@ class HomeAssistantTCPSite(web.BaseSite):
|
|||
reuse_address: bool | None = None,
|
||||
reuse_port: bool | None = None,
|
||||
) -> None:
|
||||
"""Initialize HomeAssistantTCPSite."""
|
||||
super().__init__(
|
||||
runner,
|
||||
shutdown_timeout=shutdown_timeout,
|
||||
|
@ -47,12 +48,14 @@ class HomeAssistantTCPSite(web.BaseSite):
|
|||
self._reuse_port = reuse_port
|
||||
|
||||
@property
|
||||
def name(self) -> str: # noqa: D102
|
||||
def name(self) -> str:
|
||||
"""Return server URL."""
|
||||
scheme = "https" if self._ssl_context else "http"
|
||||
host = self._host[0] if isinstance(self._host, list) else "0.0.0.0"
|
||||
return str(URL.build(scheme=scheme, host=host, port=self._port))
|
||||
|
||||
async def start(self) -> None: # noqa: D102
|
||||
async def start(self) -> None:
|
||||
"""Start server."""
|
||||
await super().start()
|
||||
loop = asyncio.get_running_loop()
|
||||
server = self._runner.server
|
||||
|
|
|
@ -593,7 +593,7 @@ SERVICE_TOGGLE_COVER_TILT = "toggle_cover_tilt"
|
|||
SERVICE_SELECT_OPTION = "select_option"
|
||||
|
||||
# #### API / REMOTE ####
|
||||
SERVER_PORT = 8123
|
||||
SERVER_PORT: Final = 8123
|
||||
|
||||
URL_ROOT = "/"
|
||||
URL_API = "/api/"
|
||||
|
|
|
@ -334,7 +334,7 @@ def async_register_implementation(
|
|||
if isinstance(implementation, LocalOAuth2Implementation) and not hass.data.get(
|
||||
DATA_VIEW_REGISTERED, False
|
||||
):
|
||||
hass.http.register_view(OAuth2AuthorizeCallbackView()) # type: ignore
|
||||
hass.http.register_view(OAuth2AuthorizeCallbackView())
|
||||
hass.data[DATA_VIEW_REGISTERED] = True
|
||||
|
||||
implementations = hass.data.setdefault(DATA_IMPLEMENTATIONS, {})
|
||||
|
|
Loading…
Reference in New Issue