diff --git a/homeassistant/components/auth/login_flow.py b/homeassistant/components/auth/login_flow.py index b24da92afdd..6cc9d94c7a6 100644 --- a/homeassistant/components/auth/login_flow.py +++ b/homeassistant/components/auth/login_flow.py @@ -257,7 +257,7 @@ class LoginFlowResourceView(LoginFlowBaseView): @RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA)) @log_invalid_auth - async def post(self, request, flow_id, data): + async def post(self, request, data, flow_id): """Handle progressing a login flow request.""" client_id = data.pop("client_id") diff --git a/homeassistant/components/http/ban.py b/homeassistant/components/http/ban.py index ee8324b2791..d2f5f9d8ba5 100644 --- a/homeassistant/components/http/ban.py +++ b/homeassistant/components/http/ban.py @@ -2,17 +2,18 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Coroutine from contextlib import suppress from datetime import datetime from http import HTTPStatus from ipaddress import IPv4Address, IPv6Address, ip_address import logging from socket import gethostbyaddr, herror -from typing import Any, Final +from typing import Any, Final, TypeVar -from aiohttp.web import Application, Request, StreamResponse, middleware +from aiohttp.web import Application, Request, Response, StreamResponse, middleware from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized +from typing_extensions import Concatenate, ParamSpec import voluptuous as vol from homeassistant.components import persistent_notification @@ -24,6 +25,9 @@ from homeassistant.util import dt as dt_util, yaml from .view import HomeAssistantView +_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView) +_P = ParamSpec("_P") + _LOGGER: Final = logging.getLogger(__name__) KEY_BAN_MANAGER: Final = "ha_banned_ips_manager" @@ -82,13 +86,13 @@ async def ban_middleware( def log_invalid_auth( - func: Callable[..., Awaitable[StreamResponse]] -) -> Callable[..., Awaitable[StreamResponse]]: + func: Callable[Concatenate[_HassViewT, Request, _P], Awaitable[Response]] +) -> Callable[Concatenate[_HassViewT, Request, _P], Coroutine[Any, Any, Response]]: """Decorate function to handle invalid auth or failed login attempts.""" async def handle_req( - view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any - ) -> StreamResponse: + view: _HassViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs + ) -> Response: """Try to log failed login attempts if response status >= BAD_REQUEST.""" resp = await func(view, request, *args, **kwargs) if resp.status >= HTTPStatus.BAD_REQUEST: diff --git a/homeassistant/components/http/data_validator.py b/homeassistant/components/http/data_validator.py index cc661d43fd8..6647a6436c5 100644 --- a/homeassistant/components/http/data_validator.py +++ b/homeassistant/components/http/data_validator.py @@ -1,17 +1,21 @@ """Decorator for view methods to help with data validation.""" from __future__ import annotations -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Coroutine from functools import wraps from http import HTTPStatus import logging -from typing import Any +from typing import Any, TypeVar from aiohttp import web +from typing_extensions import Concatenate, ParamSpec import voluptuous as vol from .view import HomeAssistantView +_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView) +_P = ParamSpec("_P") + _LOGGER = logging.getLogger(__name__) @@ -33,33 +37,40 @@ class RequestDataValidator: self._allow_empty = allow_empty def __call__( - self, method: Callable[..., Awaitable[web.StreamResponse]] - ) -> Callable: + self, + method: Callable[ + Concatenate[_HassViewT, web.Request, dict[str, Any], _P], + Awaitable[web.Response], + ], + ) -> Callable[ + Concatenate[_HassViewT, web.Request, _P], + Coroutine[Any, Any, web.Response], + ]: """Decorate a function.""" @wraps(method) async def wrapper( - view: HomeAssistantView, request: web.Request, *args: Any, **kwargs: Any - ) -> web.StreamResponse: + view: _HassViewT, request: web.Request, *args: _P.args, **kwargs: _P.kwargs + ) -> web.Response: """Wrap a request handler with data validation.""" - data = None + raw_data = None try: - data = await request.json() + raw_data = await request.json() except ValueError: if not self._allow_empty or (await request.content.read()) != b"": _LOGGER.error("Invalid JSON received") return view.json_message("Invalid JSON.", HTTPStatus.BAD_REQUEST) - data = {} + raw_data = {} try: - kwargs["data"] = self._schema(data) + data: dict[str, Any] = self._schema(raw_data) except vol.Invalid as err: _LOGGER.error("Data does not match schema: %s", err) return view.json_message( f"Message format incorrect: {err}", HTTPStatus.BAD_REQUEST ) - result = await method(view, request, *args, **kwargs) + result = await method(view, request, data, *args, **kwargs) return result return wrapper diff --git a/homeassistant/components/repairs/websocket_api.py b/homeassistant/components/repairs/websocket_api.py index b6a71773273..2e9fcc5f8e4 100644 --- a/homeassistant/components/repairs/websocket_api.py +++ b/homeassistant/components/repairs/websocket_api.py @@ -113,7 +113,7 @@ class RepairsFlowIndexView(FlowManagerIndexView): result = self._prepare_result_json(result) - return self.json(result) # pylint: disable=arguments-differ + return self.json(result) class RepairsFlowResourceView(FlowManagerResourceView): @@ -136,4 +136,4 @@ class RepairsFlowResourceView(FlowManagerResourceView): raise Unauthorized(permission=POLICY_EDIT) # pylint: disable=no-value-for-parameter - return await super().post(request, flow_id) # type: ignore[no-any-return] + return await super().post(request, flow_id) diff --git a/homeassistant/helpers/data_entry_flow.py b/homeassistant/helpers/data_entry_flow.py index 444876a7674..428a62f0c9d 100644 --- a/homeassistant/helpers/data_entry_flow.py +++ b/homeassistant/helpers/data_entry_flow.py @@ -102,7 +102,7 @@ class FlowManagerResourceView(_BaseFlowManagerView): @RequestDataValidator(vol.Schema(dict), allow_empty=True) async def post( - self, request: web.Request, flow_id: str, data: dict[str, Any] + self, request: web.Request, data: dict[str, Any], flow_id: str ) -> web.Response: """Handle a POST request.""" try: