295 lines
10 KiB
Python
295 lines
10 KiB
Python
"""Test strict connection mode for cloud."""
|
|
|
|
from collections.abc import Awaitable, Callable, Coroutine, Generator
|
|
from contextlib import contextmanager
|
|
from datetime import timedelta
|
|
from http import HTTPStatus
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
|
|
from aiohttp import ServerDisconnectedError, web
|
|
from aiohttp.test_utils import TestClient
|
|
from aiohttp_session import get_session
|
|
import pytest
|
|
from yarl import URL
|
|
|
|
from homeassistant.auth.models import RefreshToken
|
|
from homeassistant.auth.session import SESSION_ID, TEMP_TIMEOUT
|
|
from homeassistant.components.cloud.const import PREF_STRICT_CONNECTION
|
|
from homeassistant.components.http import KEY_HASS
|
|
from homeassistant.components.http.auth import (
|
|
STRICT_CONNECTION_GUARD_PAGE,
|
|
async_setup_auth,
|
|
async_sign_path,
|
|
)
|
|
from homeassistant.components.http.const import KEY_AUTHENTICATED, StrictConnectionMode
|
|
from homeassistant.components.http.session import COOKIE_NAME, PREFIXED_COOKIE_NAME
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.helpers.network import is_cloud_connection
|
|
from homeassistant.setup import async_setup_component
|
|
from homeassistant.util.dt import utcnow
|
|
|
|
from tests.common import async_fire_time_changed
|
|
from tests.typing import ClientSessionGenerator
|
|
|
|
|
|
@pytest.fixture
async def refresh_token(hass: HomeAssistant, hass_access_token: str) -> RefreshToken:
    """Provide the refresh token that backs ``hass_access_token``.

    Also checks that the auth session manager starts out with neither
    strict connection sessions nor temporary sessions.
    """
    token = hass.auth.async_validate_access_token(hass_access_token)
    assert token

    auth_session = hass.auth.session
    assert auth_session._strict_connection_sessions == {}
    assert auth_session._temp_sessions == {}

    return token
|
|
|
|
|
|
@contextmanager
def simulate_cloud_request() -> Generator[None, None, None]:
    """Make requests issued inside this context look like cloud requests.

    Patches ``hass_nabucasa.remote.is_cloud_request`` with a mock whose
    ``get()`` always returns ``True``.
    """
    cloud_request_flag = Mock(get=Mock(return_value=True))
    with patch("hass_nabucasa.remote.is_cloud_request", cloud_request_flag):
        yield
|
|
|
|
|
|
@pytest.fixture
def app_strict_connection(
    hass: HomeAssistant, refresh_token: RefreshToken
) -> web.Application:
    """Build a web.Application with a probe endpoint and a cookie endpoint.

    ``/`` reports whether the request was authenticated. ``/test/cookie``
    clears all auth sessions, creates a fresh one and returns its session id.
    """

    async def handler(request: web.Request) -> web.Response:
        """Respond with the request's authentication state as JSON."""
        is_authenticated = request[KEY_AUTHENTICATED]
        return web.json_response(data={"authenticated": is_authenticated})

    async def set_cookie(request: web.Request) -> web.Response:
        """Reset sessions, then create a refresh-token or temporary session.

        ``?token=refresh`` creates a session bound to the fixture's refresh
        token; any other value creates a temporary unauthorized session.
        """
        request_hass = request.app[KEY_HASS]
        auth_session = request_hass.auth.session

        # Clear all sessions so only the one created below exists.
        auth_session._temp_sessions.clear()
        auth_session._strict_connection_sessions.clear()

        if request.query["token"] != "refresh":
            await auth_session.async_create_temp_unauthorized_session(request)
        else:
            await auth_session.async_create_session(request, refresh_token)

        cookie_session = await get_session(request)
        return web.Response(text=cookie_session[SESSION_ID])

    app = web.Application()
    app[KEY_HASS] = hass
    app.router.add_get("/", handler)
    app.router.add_get("/test/cookie", set_cookie)
    return app
|
|
|
|
|
|
@pytest.fixture(name="client")
async def set_up_fixture(
    hass: HomeAssistant,
    aiohttp_client: ClientSessionGenerator,
    app_strict_connection: web.Application,
    cloud: MagicMock,
    socket_enabled: None,
) -> TestClient:
    """Install auth and the cloud integration, then return a test client.

    Strict connection is left DISABLED for the HTTP auth setup itself; the
    tests in this module switch modes through the cloud
    ``PREF_STRICT_CONNECTION`` preference instead.
    """
    await async_setup_auth(hass, app_strict_connection, StrictConnectionMode.DISABLED)

    setup_ok = await async_setup_component(hass, "cloud", {"cloud": {}})
    assert setup_ok
    await hass.async_block_till_done()

    test_client = await aiohttp_client(app_strict_connection)
    return test_client
|
|
|
|
|
|
# Parametrize with the enum members themselves (not ``.value``) so the
# argument matches its ``StrictConnectionMode`` annotation and mirrors
# test_strict_connection_cloud_external_unauthenticated_requests.
@pytest.mark.parametrize("strict_connection_mode", list(StrictConnectionMode))
async def test_strict_connection_cloud_authenticated_requests(
    hass: HomeAssistant,
    client: TestClient,
    hass_access_token: str,
    set_cloud_prefs: Callable[[dict[str, Any]], Coroutine[Any, Any, None]],
    refresh_token: RefreshToken,
    strict_connection_mode: StrictConnectionMode,
) -> None:
    """Test authenticated requests with strict connection.

    In every strict connection mode, a cloud request authenticated either by
    a bearer token or by a signed path is served (200) and reported as
    authenticated.
    """
    assert hass.auth.session._strict_connection_sessions == {}

    # Signed with a 5 second expiration, bound to the fixture's refresh token.
    signed_path = async_sign_path(
        hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
    )

    await set_cloud_prefs({PREF_STRICT_CONNECTION: strict_connection_mode})

    with simulate_cloud_request():
        assert is_cloud_connection(hass)

        # Bearer token authentication.
        req = await client.get(
            "/", headers={"Authorization": f"Bearer {hass_access_token}"}
        )
        assert req.status == HTTPStatus.OK
        assert await req.json() == {"authenticated": True}

        # Signed path authentication.
        req = await client.get(signed_path)
        assert req.status == HTTPStatus.OK
        assert await req.json() == {"authenticated": True}
|
|
|
|
|
|
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests(
    hass: HomeAssistant,
    client: TestClient,
    perform_unauthenticated_request: Callable[
        [HomeAssistant, TestClient], Awaitable[None]
    ],
    _: RefreshToken,
) -> None:
    """Check a cloud request that carries no session cookie at all.

    The request is issued while simulating a cloud connection and its
    outcome is verified by ``perform_unauthenticated_request``.
    """
    with simulate_cloud_request():
        cloud_connection = is_cloud_connection(hass)
        assert cloud_connection
        await perform_unauthenticated_request(hass, client)
|
|
|
|
|
|
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests_refresh_token(
    hass: HomeAssistant,
    client: TestClient,
    perform_unauthenticated_request: Callable[
        [HomeAssistant, TestClient], Awaitable[None]
    ],
    refresh_token: RefreshToken,
) -> None:
    """Check a cloud request carrying a refresh-token-backed session cookie.

    While the refresh token exists, the cookie lets the request through
    (200) without authenticating it. Once the refresh token is removed, its
    session disappears and the request is handled by
    ``perform_unauthenticated_request``.
    """
    auth_session = hass.auth.session

    # Give the client a strict connection cookie tied to the refresh token.
    session_id = await _modify_cookie_for_cloud(client, "refresh")
    assert auth_session._strict_connection_sessions == {session_id: refresh_token.id}

    with simulate_cloud_request():
        assert is_cloud_connection(hass)
        response = await client.get("/")
        assert response.status == HTTPStatus.OK
        assert await response.json() == {"authenticated": False}

        # Removing the refresh token must also drop the session bound to it.
        hass.auth.async_remove_refresh_token(refresh_token)
        assert auth_session._strict_connection_sessions == {}

        await perform_unauthenticated_request(hass, client)
|
|
|
|
|
|
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests_temp_session(
    hass: HomeAssistant,
    client: TestClient,
    perform_unauthenticated_request: Callable[
        [HomeAssistant, TestClient], Awaitable[None]
    ],
    _: RefreshToken,
) -> None:
    """Check a cloud request carrying a temporary session cookie.

    The temp session lets the request through unauthenticated (200). After
    time moves past ``TEMP_TIMEOUT``, the temp session is gone and the
    request is handled by ``perform_unauthenticated_request``.
    """
    auth_session = hass.auth.session

    # Give the client a strict connection cookie backed by a temp session.
    assert auth_session._temp_sessions == {}
    session_id = await _modify_cookie_for_cloud(client, "temp")
    assert session_id in auth_session._temp_sessions

    with simulate_cloud_request():
        assert is_cloud_connection(hass)
        response = await client.get("/")
        assert response.status == HTTPStatus.OK
        assert await response.json() == {"authenticated": False}

        # Jump one minute past TEMP_TIMEOUT so the temp session expires.
        expired_at = utcnow() + TEMP_TIMEOUT + timedelta(minutes=1)
        async_fire_time_changed(hass, expired_at)
        await hass.async_block_till_done(wait_background_tasks=True)
        assert auth_session._temp_sessions == {}

        await perform_unauthenticated_request(hass, client)
|
|
|
|
|
|
async def _drop_connection_unauthorized_request(
    _: HomeAssistant, client: TestClient
) -> None:
    """Assert that an unauthorized request to ``/`` is disconnected.

    Used for ``StrictConnectionMode.DROP_CONNECTION``, where the client is
    expected to see ``ServerDisconnectedError`` instead of a response.
    """
    expected_error = pytest.raises(ServerDisconnectedError)
    with expected_error:
        await client.get("/")
|
|
|
|
|
|
async def _guard_page_unauthorized_request(
    hass: HomeAssistant, client: TestClient
) -> None:
    """Assert that an unauthorized request to ``/`` gets the guard page.

    The response must have status 418 (I'm a teapot), and its body must
    equal the contents of ``STRICT_CONNECTION_GUARD_PAGE``.
    """

    def _load_guard_page() -> str:
        """Read the guard page file; blocking, so run it in the executor."""
        with open(STRICT_CONNECTION_GUARD_PAGE, encoding="utf-8") as guard_file:
            return guard_file.read()

    response = await client.get("/")
    assert response.status == HTTPStatus.IM_A_TEAPOT

    body = await response.text()
    assert body == await hass.async_add_executor_job(_load_guard_page)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "test_func",
    [
        _test_strict_connection_cloud_enabled_external_unauthenticated_requests,
        _test_strict_connection_cloud_enabled_external_unauthenticated_requests_refresh_token,
        _test_strict_connection_cloud_enabled_external_unauthenticated_requests_temp_session,
    ],
    ids=[
        "no cookie",
        "refresh token cookie",
        "temp session cookie",
    ],
)
@pytest.mark.parametrize(
    ("strict_connection_mode", "request_func"),
    [
        (StrictConnectionMode.DROP_CONNECTION, _drop_connection_unauthorized_request),
        (StrictConnectionMode.GUARD_PAGE, _guard_page_unauthorized_request),
    ],
    ids=["drop connection", "static page"],
)
async def test_strict_connection_cloud_external_unauthenticated_requests(
    hass: HomeAssistant,
    client: TestClient,
    refresh_token: RefreshToken,
    set_cloud_prefs: Callable[[dict[str, Any]], Coroutine[Any, Any, None]],
    test_func: Callable[
        [
            HomeAssistant,
            TestClient,
            Callable[[HomeAssistant, TestClient], Awaitable[None]],
            RefreshToken,
        ],
        Awaitable[None],
    ],
    strict_connection_mode: StrictConnectionMode,
    request_func: Callable[[HomeAssistant, TestClient], Awaitable[None]],
) -> None:
    """Run each cookie scenario against each enforcing strict connection mode.

    The cloud strict connection preference is set to the parametrized mode.
    ``test_func`` then drives one cookie scenario, and ``request_func``
    checks how that mode treats the unauthenticated request.
    """
    cloud_prefs = {PREF_STRICT_CONNECTION: strict_connection_mode}
    await set_cloud_prefs(cloud_prefs)

    await test_func(hass, client, request_func, refresh_token)
|
|
|
|
|
|
async def _modify_cookie_for_cloud(client: TestClient, token_type: str) -> str:
    """Create a session of ``token_type`` and re-store its cookie as the cloud one.

    The cloud session cookie is set with ``secure=True`` and is therefore not
    stored over the plain-HTTP test connection. To emulate it, fetch a
    session over HTTP, then move the cookie value to ``PREFIXED_COOKIE_NAME``.

    Args:
        client: Test client whose cookie jar is rewritten in place.
        token_type: ``"refresh"`` for a refresh-token session; any other
            value yields a temporary unauthorized session (as handled by the
            ``/test/cookie`` endpoint).

    Returns:
        The session id returned in the ``/test/cookie`` response body.
    """
    response = await client.get(f"/test/cookie?token={token_type}")
    session_id = await response.text()
    # Fail with a clear status mismatch rather than a later KeyError when the
    # session cookie is missing from the jar.
    assert response.status == HTTPStatus.OK, session_id

    cookie_jar = client.session.cookie_jar
    localhost = URL("http://127.0.0.1")
    cookie = cookie_jar.filter_cookies(localhost)[COOKIE_NAME].value
    assert cookie

    # Swap the plain cookie for the prefixed (secure) cookie name.
    cookie_jar.clear()
    cookie_jar.update_cookies({PREFIXED_COOKIE_NAME: cookie}, localhost)
    return session_id
|