Use aiohttp.AppKey for http ban keys (#112657)
parent
7dcf275966
commit
eb8f8e1ae4
|
@ -11,7 +11,14 @@ import logging
|
|||
from socket import gethostbyaddr, herror
|
||||
from typing import Any, Concatenate, Final, ParamSpec, TypeVar
|
||||
|
||||
from aiohttp.web import Application, Request, Response, StreamResponse, middleware
|
||||
from aiohttp.web import (
|
||||
AppKey,
|
||||
Application,
|
||||
Request,
|
||||
Response,
|
||||
StreamResponse,
|
||||
middleware,
|
||||
)
|
||||
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -29,9 +36,11 @@ _P = ParamSpec("_P")
|
|||
|
||||
_LOGGER: Final = logging.getLogger(__name__)
|
||||
|
||||
KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
|
||||
KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts"
|
||||
KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold"
|
||||
KEY_BAN_MANAGER = AppKey["IpBanManager"]("ha_banned_ips_manager")
|
||||
KEY_FAILED_LOGIN_ATTEMPTS = AppKey[defaultdict[IPv4Address | IPv6Address, int]](
|
||||
"ha_failed_login_attempts"
|
||||
)
|
||||
KEY_LOGIN_THRESHOLD = AppKey[int]("ban_manager.ip_bans_lookup")
|
||||
|
||||
NOTIFICATION_ID_BAN: Final = "ip-ban"
|
||||
NOTIFICATION_ID_LOGIN: Final = "http-login"
|
||||
|
@ -48,7 +57,7 @@ SCHEMA_IP_BAN_ENTRY: Final = vol.Schema(
|
|||
def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> None:
|
||||
"""Create IP Ban middleware for the app."""
|
||||
app.middlewares.append(ban_middleware)
|
||||
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
||||
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict[IPv4Address | IPv6Address, int](int)
|
||||
app[KEY_LOGIN_THRESHOLD] = login_threshold
|
||||
app[KEY_BAN_MANAGER] = IpBanManager(hass)
|
||||
|
||||
|
@ -64,13 +73,11 @@ async def ban_middleware(
|
|||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
) -> StreamResponse:
|
||||
"""IP Ban middleware."""
|
||||
ban_manager: IpBanManager | None = request.app.get(KEY_BAN_MANAGER)
|
||||
if ban_manager is None:
|
||||
if (ban_manager := request.app.get(KEY_BAN_MANAGER)) is None:
|
||||
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
|
||||
return await handler(request)
|
||||
|
||||
ip_bans_lookup = ban_manager.ip_bans_lookup
|
||||
if ip_bans_lookup:
|
||||
if ip_bans_lookup := ban_manager.ip_bans_lookup:
|
||||
# Verify if IP is not banned
|
||||
ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
|
||||
if ip_address_ in ip_bans_lookup:
|
||||
|
@ -154,7 +161,7 @@ async def process_wrong_login(request: Request) -> None:
|
|||
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr]
|
||||
>= request.app[KEY_LOGIN_THRESHOLD]
|
||||
):
|
||||
ban_manager: IpBanManager = request.app[KEY_BAN_MANAGER]
|
||||
ban_manager = request.app[KEY_BAN_MANAGER]
|
||||
_LOGGER.warning("Banned IP %s for too many login attempts", remote_addr)
|
||||
await ban_manager.async_add_ban(remote_addr)
|
||||
|
||||
|
@ -180,9 +187,7 @@ def process_success_login(request: Request) -> None:
|
|||
return
|
||||
|
||||
remote_addr = ip_address(request.remote) # type: ignore[arg-type]
|
||||
login_attempt_history: defaultdict[IPv4Address | IPv6Address, int] = app[
|
||||
KEY_FAILED_LOGIN_ATTEMPTS
|
||||
]
|
||||
login_attempt_history = app[KEY_FAILED_LOGIN_ATTEMPTS]
|
||||
if remote_addr in login_attempt_history and login_attempt_history[remote_addr] > 0:
|
||||
_LOGGER.debug(
|
||||
"Login success, reset failed login attempts counter from %s", remote_addr
|
||||
|
|
|
@ -15,7 +15,6 @@ from homeassistant.components.http.ban import (
|
|||
IP_BANS_FILE,
|
||||
KEY_BAN_MANAGER,
|
||||
KEY_FAILED_LOGIN_ATTEMPTS,
|
||||
IpBanManager,
|
||||
process_success_login,
|
||||
setup_bans,
|
||||
)
|
||||
|
@ -215,7 +214,7 @@ async def test_access_from_supervisor_ip(
|
|||
):
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
manager: IpBanManager = app[KEY_BAN_MANAGER]
|
||||
manager = app[KEY_BAN_MANAGER]
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.hassio.HassIO.get_resolution_info",
|
||||
|
@ -288,7 +287,7 @@ async def test_ip_bans_file_creation(
|
|||
):
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
manager: IpBanManager = app[KEY_BAN_MANAGER]
|
||||
manager = app[KEY_BAN_MANAGER]
|
||||
m_open = mock_open()
|
||||
|
||||
with patch("homeassistant.components.http.ban.open", m_open, create=True):
|
||||
|
@ -408,7 +407,7 @@ async def test_single_ban_file_entry(
|
|||
setup_bans(hass, app, 2)
|
||||
mock_real_ip(app)("200.201.202.204")
|
||||
|
||||
manager: IpBanManager = app[KEY_BAN_MANAGER]
|
||||
manager = app[KEY_BAN_MANAGER]
|
||||
m_open = mock_open()
|
||||
|
||||
with patch("homeassistant.components.http.ban.open", m_open, create=True):
|
||||
|
|
Loading…
Reference in New Issue