Add strict connection for cloud (#115814)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>pull/116085/head
parent
b520efb87a
commit
a4829330f6
|
@ -7,11 +7,14 @@ from collections.abc import Awaitable, Callable
|
|||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
from urllib.parse import quote_plus, urljoin
|
||||
|
||||
from hass_nabucasa import Cloud
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import alexa, google_assistant
|
||||
from homeassistant.components import alexa, google_assistant, http
|
||||
from homeassistant.components.auth import STRICT_CONNECTION_URL
|
||||
from homeassistant.components.http.auth import async_sign_path
|
||||
from homeassistant.config_entries import SOURCE_SYSTEM, ConfigEntry
|
||||
from homeassistant.const import (
|
||||
CONF_DESCRIPTION,
|
||||
|
@ -21,8 +24,21 @@ from homeassistant.const import (
|
|||
EVENT_HOMEASSISTANT_STOP,
|
||||
Platform,
|
||||
)
|
||||
from homeassistant.core import Event, HassJob, HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.core import (
|
||||
Event,
|
||||
HassJob,
|
||||
HomeAssistant,
|
||||
ServiceCall,
|
||||
ServiceResponse,
|
||||
SupportsResponse,
|
||||
callback,
|
||||
)
|
||||
from homeassistant.exceptions import (
|
||||
HomeAssistantError,
|
||||
ServiceValidationError,
|
||||
Unauthorized,
|
||||
UnknownUser,
|
||||
)
|
||||
from homeassistant.helpers import config_validation as cv, entityfilter
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.discovery import async_load_platform
|
||||
|
@ -31,6 +47,7 @@ from homeassistant.helpers.dispatcher import (
|
|||
async_dispatcher_send,
|
||||
)
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.helpers.network import NoURLAvailableError, get_url
|
||||
from homeassistant.helpers.service import async_register_admin_service
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
|
@ -265,18 +282,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _shutdown)
|
||||
|
||||
_remote_handle_prefs_updated(cloud)
|
||||
|
||||
async def _service_handler(service: ServiceCall) -> None:
|
||||
"""Handle service for cloud."""
|
||||
if service.service == SERVICE_REMOTE_CONNECT:
|
||||
await prefs.async_update(remote_enabled=True)
|
||||
elif service.service == SERVICE_REMOTE_DISCONNECT:
|
||||
await prefs.async_update(remote_enabled=False)
|
||||
|
||||
async_register_admin_service(hass, DOMAIN, SERVICE_REMOTE_CONNECT, _service_handler)
|
||||
async_register_admin_service(
|
||||
hass, DOMAIN, SERVICE_REMOTE_DISCONNECT, _service_handler
|
||||
)
|
||||
_setup_services(hass, prefs)
|
||||
|
||||
async def async_startup_repairs(_: datetime) -> None:
|
||||
"""Create repair issues after startup."""
|
||||
|
@ -395,3 +401,67 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
|
||||
|
||||
@callback
|
||||
def _setup_services(hass: HomeAssistant, prefs: CloudPreferences) -> None:
|
||||
"""Set up services for cloud component."""
|
||||
|
||||
async def _service_handler(service: ServiceCall) -> None:
|
||||
"""Handle service for cloud."""
|
||||
if service.service == SERVICE_REMOTE_CONNECT:
|
||||
await prefs.async_update(remote_enabled=True)
|
||||
elif service.service == SERVICE_REMOTE_DISCONNECT:
|
||||
await prefs.async_update(remote_enabled=False)
|
||||
|
||||
async_register_admin_service(hass, DOMAIN, SERVICE_REMOTE_CONNECT, _service_handler)
|
||||
async_register_admin_service(
|
||||
hass, DOMAIN, SERVICE_REMOTE_DISCONNECT, _service_handler
|
||||
)
|
||||
|
||||
async def create_temporary_strict_connection_url(
|
||||
call: ServiceCall,
|
||||
) -> ServiceResponse:
|
||||
"""Create a strict connection url and return it."""
|
||||
# Copied form homeassistant/helpers/service.py#_async_admin_handler
|
||||
# as the helper supports no responses yet
|
||||
if call.context.user_id:
|
||||
user = await hass.auth.async_get_user(call.context.user_id)
|
||||
if user is None:
|
||||
raise UnknownUser(context=call.context)
|
||||
if not user.is_admin:
|
||||
raise Unauthorized(context=call.context)
|
||||
|
||||
if prefs.strict_connection is http.const.StrictConnectionMode.DISABLED:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="strict_connection_not_enabled",
|
||||
)
|
||||
|
||||
try:
|
||||
url = get_url(hass, require_cloud=True)
|
||||
except NoURLAvailableError as ex:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="no_url_available",
|
||||
) from ex
|
||||
|
||||
path = async_sign_path(
|
||||
hass,
|
||||
STRICT_CONNECTION_URL,
|
||||
timedelta(hours=1),
|
||||
use_content_user=True,
|
||||
)
|
||||
url = urljoin(url, path)
|
||||
|
||||
return {
|
||||
"url": f"https://login.home-assistant.io?u={quote_plus(url)}",
|
||||
"direct_url": url,
|
||||
}
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
create_temporary_strict_connection_url,
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
|
|
|
@ -250,6 +250,7 @@ class CloudClient(Interface):
|
|||
"enabled": self._prefs.remote_enabled,
|
||||
"instance_domain": self.cloud.remote.instance_domain,
|
||||
"alias": self.cloud.remote.alias,
|
||||
"strict_connection": self._prefs.strict_connection,
|
||||
},
|
||||
"version": HA_VERSION,
|
||||
"instance_id": self.prefs.instance_id,
|
||||
|
|
|
@ -33,6 +33,7 @@ PREF_GOOGLE_SETTINGS_VERSION = "google_settings_version"
|
|||
PREF_TTS_DEFAULT_VOICE = "tts_default_voice"
|
||||
PREF_GOOGLE_CONNECTED = "google_connected"
|
||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE = "remote_allow_remote_enable"
|
||||
PREF_STRICT_CONNECTION = "strict_connection"
|
||||
DEFAULT_TTS_DEFAULT_VOICE = ("en-US", "JennyNeural")
|
||||
DEFAULT_DISABLE_2FA = False
|
||||
DEFAULT_ALEXA_REPORT_STATE = True
|
||||
|
|
|
@ -19,7 +19,7 @@ from hass_nabucasa.const import STATE_DISCONNECTED
|
|||
from hass_nabucasa.voice import TTS_VOICES
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components import http, websocket_api
|
||||
from homeassistant.components.alexa import (
|
||||
entities as alexa_entities,
|
||||
errors as alexa_errors,
|
||||
|
@ -46,6 +46,7 @@ from .const import (
|
|||
PREF_GOOGLE_REPORT_STATE,
|
||||
PREF_GOOGLE_SECURE_DEVICES_PIN,
|
||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE,
|
||||
PREF_STRICT_CONNECTION,
|
||||
PREF_TTS_DEFAULT_VOICE,
|
||||
REQUEST_TIMEOUT,
|
||||
)
|
||||
|
@ -452,6 +453,9 @@ def validate_language_voice(value: tuple[str, str]) -> tuple[str, str]:
|
|||
vol.Coerce(tuple), validate_language_voice
|
||||
),
|
||||
vol.Optional(PREF_REMOTE_ALLOW_REMOTE_ENABLE): bool,
|
||||
vol.Optional(PREF_STRICT_CONNECTION): vol.Coerce(
|
||||
http.const.StrictConnectionMode
|
||||
),
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
{
|
||||
"services": {
|
||||
"create_temporary_strict_connection_url": "mdi:login-variant",
|
||||
"remote_connect": "mdi:cloud",
|
||||
"remote_disconnect": "mdi:cloud-off"
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
"name": "Home Assistant Cloud",
|
||||
"after_dependencies": ["assist_pipeline", "google_assistant", "alexa"],
|
||||
"codeowners": ["@home-assistant/cloud"],
|
||||
"dependencies": ["http", "repairs", "webhook"],
|
||||
"dependencies": ["auth", "http", "repairs", "webhook"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/cloud",
|
||||
"integration_type": "system",
|
||||
"iot_class": "cloud_push",
|
||||
|
|
|
@ -10,7 +10,7 @@ from hass_nabucasa.voice import MAP_VOICE
|
|||
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.components import webhook
|
||||
from homeassistant.components import http, webhook
|
||||
from homeassistant.components.google_assistant.http import (
|
||||
async_get_users as async_get_google_assistant_users,
|
||||
)
|
||||
|
@ -44,6 +44,7 @@ from .const import (
|
|||
PREF_INSTANCE_ID,
|
||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE,
|
||||
PREF_REMOTE_DOMAIN,
|
||||
PREF_STRICT_CONNECTION,
|
||||
PREF_TTS_DEFAULT_VOICE,
|
||||
PREF_USERNAME,
|
||||
)
|
||||
|
@ -176,6 +177,7 @@ class CloudPreferences:
|
|||
google_settings_version: int | UndefinedType = UNDEFINED,
|
||||
google_connected: bool | UndefinedType = UNDEFINED,
|
||||
remote_allow_remote_enable: bool | UndefinedType = UNDEFINED,
|
||||
strict_connection: http.const.StrictConnectionMode | UndefinedType = UNDEFINED,
|
||||
) -> None:
|
||||
"""Update user preferences."""
|
||||
prefs = {**self._prefs}
|
||||
|
@ -195,6 +197,7 @@ class CloudPreferences:
|
|||
(PREF_REMOTE_DOMAIN, remote_domain),
|
||||
(PREF_GOOGLE_CONNECTED, google_connected),
|
||||
(PREF_REMOTE_ALLOW_REMOTE_ENABLE, remote_allow_remote_enable),
|
||||
(PREF_STRICT_CONNECTION, strict_connection),
|
||||
):
|
||||
if value is not UNDEFINED:
|
||||
prefs[key] = value
|
||||
|
@ -242,6 +245,7 @@ class CloudPreferences:
|
|||
PREF_GOOGLE_SECURE_DEVICES_PIN: self.google_secure_devices_pin,
|
||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE: self.remote_allow_remote_enable,
|
||||
PREF_TTS_DEFAULT_VOICE: self.tts_default_voice,
|
||||
PREF_STRICT_CONNECTION: self.strict_connection,
|
||||
}
|
||||
|
||||
@property
|
||||
|
@ -358,6 +362,17 @@ class CloudPreferences:
|
|||
"""
|
||||
return self._prefs.get(PREF_TTS_DEFAULT_VOICE, DEFAULT_TTS_DEFAULT_VOICE) # type: ignore[no-any-return]
|
||||
|
||||
@property
|
||||
def strict_connection(self) -> http.const.StrictConnectionMode:
|
||||
"""Return the strict connection mode."""
|
||||
mode = self._prefs.get(
|
||||
PREF_STRICT_CONNECTION, http.const.StrictConnectionMode.DISABLED
|
||||
)
|
||||
|
||||
if not isinstance(mode, http.const.StrictConnectionMode):
|
||||
mode = http.const.StrictConnectionMode(mode)
|
||||
return mode # type: ignore[no-any-return]
|
||||
|
||||
async def get_cloud_user(self) -> str:
|
||||
"""Return ID of Home Assistant Cloud system user."""
|
||||
user = await self._load_cloud_user()
|
||||
|
@ -415,4 +430,5 @@ class CloudPreferences:
|
|||
PREF_REMOTE_DOMAIN: None,
|
||||
PREF_REMOTE_ALLOW_REMOTE_ENABLE: True,
|
||||
PREF_USERNAME: username,
|
||||
PREF_STRICT_CONNECTION: http.const.StrictConnectionMode.DISABLED,
|
||||
}
|
||||
|
|
|
@ -5,6 +5,14 @@
|
|||
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]"
|
||||
}
|
||||
},
|
||||
"exceptions": {
|
||||
"strict_connection_not_enabled": {
|
||||
"message": "Strict connection is not enabled for cloud requests"
|
||||
},
|
||||
"no_url_available": {
|
||||
"message": "No cloud URL available.\nPlease mark sure you have a working Remote UI."
|
||||
}
|
||||
},
|
||||
"system_health": {
|
||||
"info": {
|
||||
"can_reach_cert_server": "Reach Certificate Server",
|
||||
|
@ -73,6 +81,10 @@
|
|||
}
|
||||
},
|
||||
"services": {
|
||||
"create_temporary_strict_connection_url": {
|
||||
"name": "Create a temporary strict connection URL",
|
||||
"description": "Create a temporary strict connection URL, which can be used to login on another device."
|
||||
},
|
||||
"remote_connect": {
|
||||
"name": "Remote connect",
|
||||
"description": "Makes the instance UI accessible from outside of the local network by using Home Assistant Cloud."
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
"""Cloud util functions."""
|
||||
|
||||
from hass_nabucasa import Cloud
|
||||
|
||||
from homeassistant.components import http
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .client import CloudClient
|
||||
from .const import DOMAIN
|
||||
|
||||
|
||||
def get_strict_connection_mode(hass: HomeAssistant) -> http.const.StrictConnectionMode:
|
||||
"""Get the strict connection mode."""
|
||||
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
|
||||
return cloud.client.prefs.strict_connection
|
|
@ -69,6 +69,7 @@ from homeassistant.util.json import json_loads
|
|||
from .auth import async_setup_auth, async_sign_path
|
||||
from .ban import setup_bans
|
||||
from .const import ( # noqa: F401
|
||||
DOMAIN,
|
||||
KEY_HASS_REFRESH_TOKEN_ID,
|
||||
KEY_HASS_USER,
|
||||
StrictConnectionMode,
|
||||
|
@ -82,8 +83,6 @@ from .security_filter import setup_security_filter
|
|||
from .static import CACHE_HEADERS, CachingStaticResource
|
||||
from .web_runner import HomeAssistantTCPSite
|
||||
|
||||
DOMAIN: Final = "http"
|
||||
|
||||
CONF_SERVER_HOST: Final = "server_host"
|
||||
CONF_SERVER_PORT: Final = "server_port"
|
||||
CONF_BASE_URL: Final = "base_url"
|
||||
|
@ -149,7 +148,7 @@ HTTP_SCHEMA: Final = vol.All(
|
|||
vol.Optional(CONF_USE_X_FRAME_OPTIONS, default=True): cv.boolean,
|
||||
vol.Optional(
|
||||
CONF_STRICT_CONNECTION, default=StrictConnectionMode.DISABLED
|
||||
): vol.In([e.value for e in StrictConnectionMode]),
|
||||
): vol.Coerce(StrictConnectionMode),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
@ -628,7 +627,9 @@ def _setup_services(hass: HomeAssistant, conf: ConfData) -> None:
|
|||
)
|
||||
|
||||
try:
|
||||
url = get_url(hass, prefer_external=True, allow_internal=False)
|
||||
url = get_url(
|
||||
hass, prefer_external=True, allow_internal=False, allow_cloud=False
|
||||
)
|
||||
except NoURLAvailableError as ex:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
|
|
|
@ -25,6 +25,7 @@ from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
|||
from homeassistant.auth.models import User
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import singleton
|
||||
from homeassistant.helpers.http import current_request
|
||||
from homeassistant.helpers.json import json_bytes
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
|
@ -32,6 +33,7 @@ from homeassistant.helpers.storage import Store
|
|||
from homeassistant.util.network import is_local
|
||||
|
||||
from .const import (
|
||||
DOMAIN,
|
||||
KEY_AUTHENTICATED,
|
||||
KEY_HASS_REFRESH_TOKEN_ID,
|
||||
KEY_HASS_USER,
|
||||
|
@ -50,8 +52,9 @@ STORAGE_VERSION = 1
|
|||
STORAGE_KEY = "http.auth"
|
||||
CONTENT_USER_NAME = "Home Assistant Content"
|
||||
STRICT_CONNECTION_EXCLUDED_PATH = "/api/webhook/"
|
||||
STRICT_CONNECTION_STATIC_PAGE_NAME = "strict_connection_static_page.html"
|
||||
STRICT_CONNECTION_STATIC_PAGE = os.path.join(
|
||||
os.path.dirname(__file__), "strict_connection_static_page.html"
|
||||
os.path.dirname(__file__), STRICT_CONNECTION_STATIC_PAGE_NAME
|
||||
)
|
||||
|
||||
|
||||
|
@ -156,16 +159,10 @@ async def async_setup_auth(
|
|||
await store.async_save(data)
|
||||
|
||||
hass.data[STORAGE_KEY] = refresh_token.id
|
||||
strict_connection_static_file_content = None
|
||||
|
||||
if strict_connection_mode_non_cloud is StrictConnectionMode.STATIC_PAGE:
|
||||
|
||||
def read_static_page() -> str:
|
||||
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
strict_connection_static_file_content = await hass.async_add_executor_job(
|
||||
read_static_page
|
||||
)
|
||||
# Load the static page content on setup
|
||||
await _read_strict_connection_static_page(hass)
|
||||
|
||||
@callback
|
||||
def async_validate_auth_header(request: Request) -> bool:
|
||||
|
@ -255,21 +252,36 @@ async def async_setup_auth(
|
|||
authenticated = True
|
||||
auth_type = "signed request"
|
||||
|
||||
if (
|
||||
not authenticated
|
||||
and strict_connection_mode_non_cloud is not StrictConnectionMode.DISABLED
|
||||
and not request.path.startswith(STRICT_CONNECTION_EXCLUDED_PATH)
|
||||
and not await hass.auth.session.async_validate_request_for_strict_connection_session(
|
||||
request
|
||||
)
|
||||
and (
|
||||
resp := _async_perform_action_on_non_local(
|
||||
request, strict_connection_static_file_content
|
||||
)
|
||||
)
|
||||
is not None
|
||||
if not authenticated and not request.path.startswith(
|
||||
STRICT_CONNECTION_EXCLUDED_PATH
|
||||
):
|
||||
return resp
|
||||
strict_connection_mode = strict_connection_mode_non_cloud
|
||||
strict_connection_func = (
|
||||
_async_perform_strict_connection_action_on_non_local
|
||||
)
|
||||
if is_cloud_connection(hass):
|
||||
from homeassistant.components.cloud.util import ( # pylint: disable=import-outside-toplevel
|
||||
get_strict_connection_mode,
|
||||
)
|
||||
|
||||
strict_connection_mode = get_strict_connection_mode(hass)
|
||||
strict_connection_func = _async_perform_strict_connection_action
|
||||
|
||||
if (
|
||||
strict_connection_mode is not StrictConnectionMode.DISABLED
|
||||
and not await hass.auth.session.async_validate_request_for_strict_connection_session(
|
||||
request
|
||||
)
|
||||
and (
|
||||
resp := await strict_connection_func(
|
||||
hass,
|
||||
request,
|
||||
strict_connection_mode is StrictConnectionMode.STATIC_PAGE,
|
||||
)
|
||||
)
|
||||
is not None
|
||||
):
|
||||
return resp
|
||||
|
||||
if authenticated and _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
|
@ -286,17 +298,17 @@ async def async_setup_auth(
|
|||
app.middlewares.append(auth_middleware)
|
||||
|
||||
|
||||
@callback
|
||||
def _async_perform_action_on_non_local(
|
||||
async def _async_perform_strict_connection_action_on_non_local(
|
||||
hass: HomeAssistant,
|
||||
request: Request,
|
||||
strict_connection_static_file_content: str | None,
|
||||
static_page: bool,
|
||||
) -> StreamResponse | None:
|
||||
"""Perform strict connection mode action if the request is not local.
|
||||
|
||||
The function does the following:
|
||||
- Try to get the IP address of the request. If it fails, assume it's not local
|
||||
- If the request is local, return None (allow the request to continue)
|
||||
- If strict_connection_static_file_content is set, return a response with the content
|
||||
- If static_page is True, return a response with the content
|
||||
- Otherwise close the connection and raise an exception
|
||||
"""
|
||||
try:
|
||||
|
@ -308,10 +320,25 @@ def _async_perform_action_on_non_local(
|
|||
if ip_address_ and is_local(ip_address_):
|
||||
return None
|
||||
|
||||
_LOGGER.debug("Perform strict connection action for %s", ip_address_)
|
||||
if strict_connection_static_file_content:
|
||||
return await _async_perform_strict_connection_action(hass, request, static_page)
|
||||
|
||||
|
||||
async def _async_perform_strict_connection_action(
|
||||
hass: HomeAssistant,
|
||||
request: Request,
|
||||
static_page: bool,
|
||||
) -> StreamResponse | None:
|
||||
"""Perform strict connection mode action.
|
||||
|
||||
The function does the following:
|
||||
- If static_page is True, return a response with the content
|
||||
- Otherwise close the connection and raise an exception
|
||||
"""
|
||||
|
||||
_LOGGER.debug("Perform strict connection action for %s", request.remote)
|
||||
if static_page:
|
||||
return Response(
|
||||
text=strict_connection_static_file_content,
|
||||
text=await _read_strict_connection_static_page(hass),
|
||||
content_type="text/html",
|
||||
status=HTTPStatus.IM_A_TEAPOT,
|
||||
)
|
||||
|
@ -322,3 +349,14 @@ def _async_perform_action_on_non_local(
|
|||
|
||||
# We need to raise an exception to stop processing the request
|
||||
raise HTTPBadRequest
|
||||
|
||||
|
||||
@singleton.singleton(f"{DOMAIN}_{STRICT_CONNECTION_STATIC_PAGE_NAME}")
|
||||
async def _read_strict_connection_static_page(hass: HomeAssistant) -> str:
|
||||
"""Read the strict connection static page from disk via executor."""
|
||||
|
||||
def read_static_page() -> str:
|
||||
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
return await hass.async_add_executor_job(read_static_page)
|
||||
|
|
|
@ -5,6 +5,8 @@ from typing import Final
|
|||
|
||||
from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401
|
||||
|
||||
DOMAIN: Final = "http"
|
||||
|
||||
KEY_HASS_USER: Final = "hass_user"
|
||||
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"
|
||||
|
||||
|
|
|
@ -122,6 +122,7 @@ def get_url(
|
|||
require_current_request: bool = False,
|
||||
require_ssl: bool = False,
|
||||
require_standard_port: bool = False,
|
||||
require_cloud: bool = False,
|
||||
allow_internal: bool = True,
|
||||
allow_external: bool = True,
|
||||
allow_cloud: bool = True,
|
||||
|
@ -145,7 +146,7 @@ def get_url(
|
|||
|
||||
# Try finding an URL in the order specified
|
||||
for url_type in order:
|
||||
if allow_internal and url_type == TYPE_URL_INTERNAL:
|
||||
if allow_internal and url_type == TYPE_URL_INTERNAL and not require_cloud:
|
||||
with suppress(NoURLAvailableError):
|
||||
return _get_internal_url(
|
||||
hass,
|
||||
|
@ -155,7 +156,7 @@ def get_url(
|
|||
require_standard_port=require_standard_port,
|
||||
)
|
||||
|
||||
if allow_external and url_type == TYPE_URL_EXTERNAL:
|
||||
if require_cloud or (allow_external and url_type == TYPE_URL_EXTERNAL):
|
||||
with suppress(NoURLAvailableError):
|
||||
return _get_external_url(
|
||||
hass,
|
||||
|
@ -165,7 +166,10 @@ def get_url(
|
|||
require_current_request=require_current_request,
|
||||
require_ssl=require_ssl,
|
||||
require_standard_port=require_standard_port,
|
||||
require_cloud=require_cloud,
|
||||
)
|
||||
if require_cloud:
|
||||
raise NoURLAvailableError
|
||||
|
||||
# For current request, we accept loopback interfaces (e.g., 127.0.0.1),
|
||||
# the Supervisor hostname and localhost transparently
|
||||
|
@ -263,8 +267,12 @@ def _get_external_url(
|
|||
require_current_request: bool = False,
|
||||
require_ssl: bool = False,
|
||||
require_standard_port: bool = False,
|
||||
require_cloud: bool = False,
|
||||
) -> str:
|
||||
"""Get external URL of this instance."""
|
||||
if require_cloud:
|
||||
return _get_cloud_url(hass, require_current_request=require_current_request)
|
||||
|
||||
if prefer_cloud and allow_cloud:
|
||||
with suppress(NoURLAvailableError):
|
||||
return _get_cloud_url(hass)
|
||||
|
|
|
@ -152,6 +152,7 @@ IGNORE_VIOLATIONS = {
|
|||
("demo", "manual"),
|
||||
# This would be a circular dep
|
||||
("http", "network"),
|
||||
("http", "cloud"),
|
||||
# This would be a circular dep
|
||||
("zha", "homeassistant_hardware"),
|
||||
("zha", "homeassistant_sky_connect"),
|
||||
|
|
|
@ -24,6 +24,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
|
|||
ExposedEntities,
|
||||
async_expose_entity,
|
||||
)
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.const import CONTENT_TYPE_JSON, __version__ as HA_VERSION
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
|
@ -387,6 +388,7 @@ async def test_cloud_connection_info(hass: HomeAssistant) -> None:
|
|||
"connected": False,
|
||||
"enabled": False,
|
||||
"instance_domain": None,
|
||||
"strict_connection": StrictConnectionMode.DISABLED,
|
||||
},
|
||||
"version": HA_VERSION,
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
|
|||
from homeassistant.components.cloud.const import DEFAULT_EXPOSED_DOMAINS, DOMAIN
|
||||
from homeassistant.components.google_assistant.helpers import GoogleEntity
|
||||
from homeassistant.components.homeassistant import exposed_entities
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.components.websocket_api import ERR_INVALID_FORMAT
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
|
@ -782,6 +783,7 @@ async def test_websocket_status(
|
|||
"google_report_state": True,
|
||||
"remote_allow_remote_enable": True,
|
||||
"remote_enabled": False,
|
||||
"strict_connection": "disabled",
|
||||
"tts_default_voice": ["en-US", "JennyNeural"],
|
||||
},
|
||||
"alexa_entities": {
|
||||
|
@ -901,6 +903,7 @@ async def test_websocket_update_preferences(
|
|||
assert cloud.client.prefs.alexa_enabled
|
||||
assert cloud.client.prefs.google_secure_devices_pin is None
|
||||
assert cloud.client.prefs.remote_allow_remote_enable is True
|
||||
assert cloud.client.prefs.strict_connection is StrictConnectionMode.DISABLED
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
|
@ -912,6 +915,7 @@ async def test_websocket_update_preferences(
|
|||
"google_secure_devices_pin": "1234",
|
||||
"tts_default_voice": ["en-GB", "RyanNeural"],
|
||||
"remote_allow_remote_enable": False,
|
||||
"strict_connection": StrictConnectionMode.DROP_CONNECTION,
|
||||
}
|
||||
)
|
||||
response = await client.receive_json()
|
||||
|
@ -922,6 +926,7 @@ async def test_websocket_update_preferences(
|
|||
assert cloud.client.prefs.google_secure_devices_pin == "1234"
|
||||
assert cloud.client.prefs.remote_allow_remote_enable is False
|
||||
assert cloud.client.prefs.tts_default_voice == ("en-GB", "RyanNeural")
|
||||
assert cloud.client.prefs.strict_connection is StrictConnectionMode.DROP_CONNECTION
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from hass_nabucasa import Cloud
|
||||
import pytest
|
||||
|
@ -13,11 +14,16 @@ from homeassistant.components.cloud import (
|
|||
CloudNotConnected,
|
||||
async_get_or_create_cloudhook,
|
||||
)
|
||||
from homeassistant.components.cloud.const import DOMAIN, PREF_CLOUDHOOKS
|
||||
from homeassistant.components.cloud.const import (
|
||||
DOMAIN,
|
||||
PREF_CLOUDHOOKS,
|
||||
PREF_STRICT_CONNECTION,
|
||||
)
|
||||
from homeassistant.components.cloud.prefs import STORAGE_KEY
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import Unauthorized
|
||||
from homeassistant.exceptions import ServiceValidationError, Unauthorized
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry, MockUser
|
||||
|
@ -295,3 +301,77 @@ async def test_cloud_logout(
|
|||
await hass.async_block_till_done()
|
||||
|
||||
assert cloud.is_logged_in is False
|
||||
|
||||
|
||||
async def test_service_create_temporary_strict_connection_url_strict_connection_disabled(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test service create_temporary_strict_connection_url with strict_connection not enabled."""
|
||||
mock_config_entry = MockConfigEntry(domain=DOMAIN)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
with pytest.raises(
|
||||
ServiceValidationError,
|
||||
match="Strict connection is not enabled for cloud requests",
|
||||
):
|
||||
await hass.services.async_call(
|
||||
cloud.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mode"),
|
||||
[
|
||||
StrictConnectionMode.DROP_CONNECTION,
|
||||
StrictConnectionMode.STATIC_PAGE,
|
||||
],
|
||||
)
|
||||
async def test_service_create_temporary_strict_connection(
|
||||
hass: HomeAssistant,
|
||||
set_cloud_prefs: Callable[[dict[str, Any]], Coroutine[Any, Any, None]],
|
||||
mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test service create_temporary_strict_connection_url."""
|
||||
mock_config_entry = MockConfigEntry(domain=DOMAIN)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
await set_cloud_prefs(
|
||||
{
|
||||
PREF_STRICT_CONNECTION: mode,
|
||||
}
|
||||
)
|
||||
|
||||
# No cloud url set
|
||||
with pytest.raises(ServiceValidationError, match="No cloud URL available"):
|
||||
await hass.services.async_call(
|
||||
cloud.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
# Patch cloud url
|
||||
url = "https://example.com"
|
||||
with patch(
|
||||
"homeassistant.helpers.network._get_cloud_url",
|
||||
return_value=url,
|
||||
):
|
||||
response = await hass.services.async_call(
|
||||
cloud.DOMAIN,
|
||||
"create_temporary_strict_connection_url",
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
assert isinstance(response, dict)
|
||||
direct_url_prefix = f"{url}/auth/strict_connection/temp_token?authSig="
|
||||
assert response.pop("direct_url").startswith(direct_url_prefix)
|
||||
assert response.pop("url").startswith(
|
||||
f"https://login.home-assistant.io?u={quote_plus(direct_url_prefix)}"
|
||||
)
|
||||
assert response == {} # No more keys in response
|
||||
|
|
|
@ -6,8 +6,13 @@ from unittest.mock import ANY, MagicMock, patch
|
|||
import pytest
|
||||
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||
from homeassistant.components.cloud.const import DOMAIN, PREF_TTS_DEFAULT_VOICE
|
||||
from homeassistant.components.cloud.const import (
|
||||
DOMAIN,
|
||||
PREF_STRICT_CONNECTION,
|
||||
PREF_TTS_DEFAULT_VOICE,
|
||||
)
|
||||
from homeassistant.components.cloud.prefs import STORAGE_KEY, CloudPreferences
|
||||
from homeassistant.components.http.const import StrictConnectionMode
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
@ -174,3 +179,21 @@ async def test_tts_default_voice_legacy_gender(
|
|||
await hass.async_block_till_done()
|
||||
|
||||
assert cloud.client.prefs.tts_default_voice == (expected_language, voice)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mode", list(StrictConnectionMode))
|
||||
async def test_strict_connection_convertion(
|
||||
hass: HomeAssistant,
|
||||
cloud: MagicMock,
|
||||
hass_storage: dict[str, Any],
|
||||
mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test strict connection string value will be converted to the enum."""
|
||||
hass_storage[STORAGE_KEY] = {
|
||||
"version": 1,
|
||||
"data": {PREF_STRICT_CONNECTION: mode.value},
|
||||
}
|
||||
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert cloud.client.prefs.strict_connection is mode
|
||||
|
|
|
@ -0,0 +1,294 @@
|
|||
"""Test strict connection mode for cloud."""
|
||||
|
||||
from collections.abc import Awaitable, Callable, Coroutine, Generator
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from aiohttp import ServerDisconnectedError, web
|
||||
from aiohttp.test_utils import TestClient
|
||||
from aiohttp_session import get_session
|
||||
import pytest
|
||||
from yarl import URL
|
||||
|
||||
from homeassistant.auth.models import RefreshToken
|
||||
from homeassistant.auth.session import SESSION_ID, TEMP_TIMEOUT
|
||||
from homeassistant.components.cloud.const import PREF_STRICT_CONNECTION
|
||||
from homeassistant.components.http import KEY_HASS
|
||||
from homeassistant.components.http.auth import (
|
||||
STRICT_CONNECTION_STATIC_PAGE,
|
||||
async_setup_auth,
|
||||
async_sign_path,
|
||||
)
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED, StrictConnectionMode
|
||||
from homeassistant.components.http.session import COOKIE_NAME, PREFIXED_COOKIE_NAME
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def refresh_token(hass: HomeAssistant, hass_access_token: str) -> RefreshToken:
|
||||
"""Return a refresh token."""
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
assert refresh_token
|
||||
session = hass.auth.session
|
||||
assert session._strict_connection_sessions == {}
|
||||
assert session._temp_sessions == {}
|
||||
return refresh_token
|
||||
|
||||
|
||||
@contextmanager
|
||||
def simulate_cloud_request() -> Generator[None, None, None]:
|
||||
"""Simulate a cloud request."""
|
||||
with patch(
|
||||
"hass_nabucasa.remote.is_cloud_request", Mock(get=Mock(return_value=True))
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_strict_connection(
|
||||
hass: HomeAssistant, refresh_token: RefreshToken
|
||||
) -> web.Application:
|
||||
"""Fixture to set up a web.Application."""
|
||||
|
||||
async def handler(request):
|
||||
"""Return if request was authenticated."""
|
||||
return web.json_response(data={"authenticated": request[KEY_AUTHENTICATED]})
|
||||
|
||||
app = web.Application()
|
||||
app[KEY_HASS] = hass
|
||||
app.router.add_get("/", handler)
|
||||
|
||||
async def set_cookie(request: web.Request) -> web.Response:
|
||||
hass = request.app[KEY_HASS]
|
||||
# Clear all sessions
|
||||
hass.auth.session._temp_sessions.clear()
|
||||
hass.auth.session._strict_connection_sessions.clear()
|
||||
|
||||
if request.query["token"] == "refresh":
|
||||
await hass.auth.session.async_create_session(request, refresh_token)
|
||||
else:
|
||||
await hass.auth.session.async_create_temp_unauthorized_session(request)
|
||||
session = await get_session(request)
|
||||
return web.Response(text=session[SESSION_ID])
|
||||
|
||||
app.router.add_get("/test/cookie", set_cookie)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(name="client")
|
||||
async def set_up_fixture(
|
||||
hass: HomeAssistant,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
app_strict_connection: web.Application,
|
||||
cloud: MagicMock,
|
||||
socket_enabled: None,
|
||||
) -> TestClient:
|
||||
"""Set up the fixture."""
|
||||
|
||||
await async_setup_auth(hass, app_strict_connection, StrictConnectionMode.DISABLED)
|
||||
assert await async_setup_component(hass, "cloud", {"cloud": {}})
|
||||
await hass.async_block_till_done()
|
||||
return await aiohttp_client(app_strict_connection)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"strict_connection_mode", [e.value for e in StrictConnectionMode]
|
||||
)
|
||||
async def test_strict_connection_cloud_authenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
client: TestClient,
|
||||
hass_access_token: str,
|
||||
set_cloud_prefs: Callable[[dict[str, Any]], Coroutine[Any, Any, None]],
|
||||
refresh_token: RefreshToken,
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
) -> None:
|
||||
"""Test authenticated requests with strict connection."""
|
||||
assert hass.auth.session._strict_connection_sessions == {}
|
||||
|
||||
signed_path = async_sign_path(
|
||||
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
|
||||
)
|
||||
|
||||
await set_cloud_prefs(
|
||||
{
|
||||
PREF_STRICT_CONNECTION: strict_connection_mode,
|
||||
}
|
||||
)
|
||||
|
||||
with simulate_cloud_request():
|
||||
assert is_cloud_connection(hass)
|
||||
req = await client.get(
|
||||
"/", headers={"Authorization": f"Bearer {hass_access_token}"}
|
||||
)
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": True}
|
||||
req = await client.get(signed_path)
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": True}
|
||||
|
||||
|
||||
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
client: TestClient,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
_: RefreshToken,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection cloud enabled."""
|
||||
with simulate_cloud_request():
|
||||
assert is_cloud_connection(hass)
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests_refresh_token(
|
||||
hass: HomeAssistant,
|
||||
client: TestClient,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
refresh_token: RefreshToken,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection cloud enabled and refresh token cookie."""
|
||||
session = hass.auth.session
|
||||
|
||||
# set strict connection cookie with refresh token
|
||||
session_id = await _modify_cookie_for_cloud(client, "refresh")
|
||||
assert session._strict_connection_sessions == {session_id: refresh_token.id}
|
||||
with simulate_cloud_request():
|
||||
assert is_cloud_connection(hass)
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
assert await req.json() == {"authenticated": False}
|
||||
|
||||
# Invalidate refresh token, which should also invalidate session
|
||||
hass.auth.async_remove_refresh_token(refresh_token)
|
||||
assert session._strict_connection_sessions == {}
|
||||
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests_temp_session(
|
||||
hass: HomeAssistant,
|
||||
client: TestClient,
|
||||
perform_unauthenticated_request: Callable[
|
||||
[HomeAssistant, TestClient], Awaitable[None]
|
||||
],
|
||||
_: RefreshToken,
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection cloud enabled and temp cookie."""
|
||||
session = hass.auth.session
|
||||
|
||||
# set strict connection cookie with temp session
|
||||
assert session._temp_sessions == {}
|
||||
session_id = await _modify_cookie_for_cloud(client, "temp")
|
||||
assert session_id in session._temp_sessions
|
||||
with simulate_cloud_request():
|
||||
assert is_cloud_connection(hass)
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.OK
|
||||
assert await resp.json() == {"authenticated": False}
|
||||
|
||||
async_fire_time_changed(hass, utcnow() + TEMP_TIMEOUT + timedelta(minutes=1))
|
||||
await hass.async_block_till_done(wait_background_tasks=True)
|
||||
assert session._temp_sessions == {}
|
||||
|
||||
await perform_unauthenticated_request(hass, client)
|
||||
|
||||
|
||||
async def _drop_connection_unauthorized_request(
|
||||
_: HomeAssistant, client: TestClient
|
||||
) -> None:
|
||||
with pytest.raises(ServerDisconnectedError):
|
||||
# unauthorized requests should raise ServerDisconnectedError
|
||||
await client.get("/")
|
||||
|
||||
|
||||
async def _static_page_unauthorized_request(
|
||||
hass: HomeAssistant, client: TestClient
|
||||
) -> None:
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.IM_A_TEAPOT
|
||||
|
||||
def read_static_page() -> str:
|
||||
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
assert await req.text() == await hass.async_add_executor_job(read_static_page)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_func",
|
||||
[
|
||||
_test_strict_connection_cloud_enabled_external_unauthenticated_requests,
|
||||
_test_strict_connection_cloud_enabled_external_unauthenticated_requests_refresh_token,
|
||||
_test_strict_connection_cloud_enabled_external_unauthenticated_requests_temp_session,
|
||||
],
|
||||
ids=[
|
||||
"no cookie",
|
||||
"refresh token cookie",
|
||||
"temp session cookie",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("strict_connection_mode", "request_func"),
|
||||
[
|
||||
(StrictConnectionMode.DROP_CONNECTION, _drop_connection_unauthorized_request),
|
||||
(StrictConnectionMode.STATIC_PAGE, _static_page_unauthorized_request),
|
||||
],
|
||||
ids=["drop connection", "static page"],
|
||||
)
|
||||
async def test_strict_connection_cloud_external_unauthenticated_requests(
|
||||
hass: HomeAssistant,
|
||||
client: TestClient,
|
||||
refresh_token: RefreshToken,
|
||||
set_cloud_prefs: Callable[[dict[str, Any]], Coroutine[Any, Any, None]],
|
||||
test_func: Callable[
|
||||
[
|
||||
HomeAssistant,
|
||||
TestClient,
|
||||
Callable[[HomeAssistant, TestClient], Awaitable[None]],
|
||||
RefreshToken,
|
||||
],
|
||||
Awaitable[None],
|
||||
],
|
||||
strict_connection_mode: StrictConnectionMode,
|
||||
request_func: Callable[[HomeAssistant, TestClient], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Test external unauthenticated requests with strict connection cloud."""
|
||||
await set_cloud_prefs(
|
||||
{
|
||||
PREF_STRICT_CONNECTION: strict_connection_mode,
|
||||
}
|
||||
)
|
||||
|
||||
await test_func(
|
||||
hass,
|
||||
client,
|
||||
request_func,
|
||||
refresh_token,
|
||||
)
|
||||
|
||||
|
||||
async def _modify_cookie_for_cloud(client: TestClient, token_type: str) -> str:
|
||||
"""Modify cookie for cloud."""
|
||||
# Cloud cookie has set secure=true and will not set on unsecure connection
|
||||
# As we test with unsecure connection, we need to set it manually
|
||||
# We get the session via http and modify the cookie name to the secure one
|
||||
session_id = await (await client.get(f"/test/cookie?token={token_type}")).text()
|
||||
cookie_jar = client.session.cookie_jar
|
||||
localhost = URL("http://127.0.0.1")
|
||||
cookie = cookie_jar.filter_cookies(localhost)[COOKIE_NAME].value
|
||||
assert cookie
|
||||
cookie_jar.clear()
|
||||
cookie_jar.update_cookies({PREFIXED_COOKIE_NAME: cookie}, localhost)
|
||||
return session_id
|
|
@ -362,6 +362,18 @@ async def test_get_url_external(hass: HomeAssistant) -> None:
|
|||
with pytest.raises(NoURLAvailableError):
|
||||
_get_external_url(hass, require_current_request=True, require_ssl=True)
|
||||
|
||||
with pytest.raises(NoURLAvailableError):
|
||||
_get_external_url(hass, require_cloud=True)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.cloud.async_remote_ui_url",
|
||||
return_value="https://example.nabu.casa",
|
||||
):
|
||||
hass.config.components.add("cloud")
|
||||
assert (
|
||||
_get_external_url(hass, require_cloud=True) == "https://example.nabu.casa"
|
||||
)
|
||||
|
||||
|
||||
async def test_get_cloud_url(hass: HomeAssistant) -> None:
|
||||
"""Test getting an instance URL when the user has set an external URL."""
|
||||
|
|
Loading…
Reference in New Issue