Improve http decorator typing (#75541)
parent
1d7d2875e1
commit
b1ed1543c8
|
@ -257,7 +257,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
|
|||
|
||||
@RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA))
|
||||
@log_invalid_auth
|
||||
async def post(self, request, flow_id, data):
|
||||
async def post(self, request, data, flow_id):
|
||||
"""Handle progressing a login flow request."""
|
||||
client_id = data.pop("client_id")
|
||||
|
||||
|
|
|
@ -2,17 +2,18 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from ipaddress import IPv4Address, IPv6Address, ip_address
|
||||
import logging
|
||||
from socket import gethostbyaddr, herror
|
||||
from typing import Any, Final
|
||||
from typing import Any, Final, TypeVar
|
||||
|
||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||
from aiohttp.web import Application, Request, Response, StreamResponse, middleware
|
||||
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import persistent_notification
|
||||
|
@ -24,6 +25,9 @@ from homeassistant.util import dt as dt_util, yaml
|
|||
|
||||
from .view import HomeAssistantView
|
||||
|
||||
_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
_LOGGER: Final = logging.getLogger(__name__)
|
||||
|
||||
KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
|
||||
|
@ -82,13 +86,13 @@ async def ban_middleware(
|
|||
|
||||
|
||||
def log_invalid_auth(
|
||||
func: Callable[..., Awaitable[StreamResponse]]
|
||||
) -> Callable[..., Awaitable[StreamResponse]]:
|
||||
func: Callable[Concatenate[_HassViewT, Request, _P], Awaitable[Response]]
|
||||
) -> Callable[Concatenate[_HassViewT, Request, _P], Coroutine[Any, Any, Response]]:
|
||||
"""Decorate function to handle invalid auth or failed login attempts."""
|
||||
|
||||
async def handle_req(
|
||||
view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any
|
||||
) -> StreamResponse:
|
||||
view: _HassViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
|
||||
) -> Response:
|
||||
"""Try to log failed login attempts if response status >= BAD_REQUEST."""
|
||||
resp = await func(view, request, *args, **kwargs)
|
||||
if resp.status >= HTTPStatus.BAD_REQUEST:
|
||||
|
|
|
@ -1,17 +1,21 @@
|
|||
"""Decorator for view methods to help with data validation."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from functools import wraps
|
||||
from http import HTTPStatus
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from aiohttp import web
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
import voluptuous as vol
|
||||
|
||||
from .view import HomeAssistantView
|
||||
|
||||
_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -33,33 +37,40 @@ class RequestDataValidator:
|
|||
self._allow_empty = allow_empty
|
||||
|
||||
def __call__(
|
||||
self, method: Callable[..., Awaitable[web.StreamResponse]]
|
||||
) -> Callable:
|
||||
self,
|
||||
method: Callable[
|
||||
Concatenate[_HassViewT, web.Request, dict[str, Any], _P],
|
||||
Awaitable[web.Response],
|
||||
],
|
||||
) -> Callable[
|
||||
Concatenate[_HassViewT, web.Request, _P],
|
||||
Coroutine[Any, Any, web.Response],
|
||||
]:
|
||||
"""Decorate a function."""
|
||||
|
||||
@wraps(method)
|
||||
async def wrapper(
|
||||
view: HomeAssistantView, request: web.Request, *args: Any, **kwargs: Any
|
||||
) -> web.StreamResponse:
|
||||
view: _HassViewT, request: web.Request, *args: _P.args, **kwargs: _P.kwargs
|
||||
) -> web.Response:
|
||||
"""Wrap a request handler with data validation."""
|
||||
data = None
|
||||
raw_data = None
|
||||
try:
|
||||
data = await request.json()
|
||||
raw_data = await request.json()
|
||||
except ValueError:
|
||||
if not self._allow_empty or (await request.content.read()) != b"":
|
||||
_LOGGER.error("Invalid JSON received")
|
||||
return view.json_message("Invalid JSON.", HTTPStatus.BAD_REQUEST)
|
||||
data = {}
|
||||
raw_data = {}
|
||||
|
||||
try:
|
||||
kwargs["data"] = self._schema(data)
|
||||
data: dict[str, Any] = self._schema(raw_data)
|
||||
except vol.Invalid as err:
|
||||
_LOGGER.error("Data does not match schema: %s", err)
|
||||
return view.json_message(
|
||||
f"Message format incorrect: {err}", HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
result = await method(view, request, *args, **kwargs)
|
||||
result = await method(view, request, data, *args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
|
|
@ -113,7 +113,7 @@ class RepairsFlowIndexView(FlowManagerIndexView):
|
|||
|
||||
result = self._prepare_result_json(result)
|
||||
|
||||
return self.json(result) # pylint: disable=arguments-differ
|
||||
return self.json(result)
|
||||
|
||||
|
||||
class RepairsFlowResourceView(FlowManagerResourceView):
|
||||
|
@ -136,4 +136,4 @@ class RepairsFlowResourceView(FlowManagerResourceView):
|
|||
raise Unauthorized(permission=POLICY_EDIT)
|
||||
|
||||
# pylint: disable=no-value-for-parameter
|
||||
return await super().post(request, flow_id) # type: ignore[no-any-return]
|
||||
return await super().post(request, flow_id)
|
||||
|
|
|
@ -102,7 +102,7 @@ class FlowManagerResourceView(_BaseFlowManagerView):
|
|||
|
||||
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
|
||||
async def post(
|
||||
self, request: web.Request, flow_id: str, data: dict[str, Any]
|
||||
self, request: web.Request, data: dict[str, Any], flow_id: str
|
||||
) -> web.Response:
|
||||
"""Handle a POST request."""
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue