From 677abcd48470042b09e47e9979c5361cc8e59490 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 22 Sep 2021 14:17:04 -0700 Subject: [PATCH] Allow confirming local push notifications (#54947) * Allow confirming local push notifications * Fix from Zac * Add tests --- .../components/mobile_app/__init__.py | 59 +------- homeassistant/components/mobile_app/notify.py | 97 ++++++------- .../mobile_app/push_notification.py | 90 ++++++++++++ .../components/mobile_app/websocket_api.py | 121 +++++++++++++++++ .../components/websocket_api/connection.py | 4 +- .../components/websocket_api/http.py | 2 +- tests/components/mobile_app/test_notify.py | 128 ++++++++++++++++++ 7 files changed, 397 insertions(+), 104 deletions(-) create mode 100644 homeassistant/components/mobile_app/push_notification.py create mode 100644 homeassistant/components/mobile_app/websocket_api.py diff --git a/homeassistant/components/mobile_app/__init__.py b/homeassistant/components/mobile_app/__init__.py index 1fc5be2a890..73775f23e6d 100644 --- a/homeassistant/components/mobile_app/__init__.py +++ b/homeassistant/components/mobile_app/__init__.py @@ -1,25 +1,23 @@ """Integrates Native Apps to Home Assistant.""" from contextlib import suppress -import voluptuous as vol - -from homeassistant.components import cloud, notify as hass_notify, websocket_api +from homeassistant.components import cloud, notify as hass_notify from homeassistant.components.webhook import ( async_register as webhook_register, async_unregister as webhook_unregister, ) from homeassistant.const import ATTR_DEVICE_ID, CONF_WEBHOOK_ID -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr, discovery from homeassistant.helpers.typing import ConfigType +from . import websocket_api from .const import ( ATTR_DEVICE_NAME, ATTR_MANUFACTURER, ATTR_MODEL, ATTR_OS_VERSION, CONF_CLOUDHOOK_URL, - CONF_USER_ID, DATA_CONFIG_ENTRIES, DATA_DELETED_IDS, DATA_DEVICES, @@ -66,7 +64,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: discovery.async_load_platform(hass, "notify", DOMAIN, {}, config) ) - websocket_api.async_register_command(hass, handle_push_notification_channel) + websocket_api.async_setup_commands(hass) return True @@ -127,52 +125,3 @@ async def async_remove_entry(hass, entry): if CONF_CLOUDHOOK_URL in entry.data: with suppress(cloud.CloudNotAvailable): await cloud.async_delete_cloudhook(hass, entry.data[CONF_WEBHOOK_ID]) - - -@callback -@websocket_api.websocket_command( - { - vol.Required("type"): "mobile_app/push_notification_channel", - vol.Required("webhook_id"): str, - } -) -def handle_push_notification_channel(hass, connection, msg): - """Set up a direct push notification channel.""" - webhook_id = msg["webhook_id"] - - # Validate that the webhook ID is registered to the user of the websocket connection - config_entry = hass.data[DOMAIN][DATA_CONFIG_ENTRIES].get(webhook_id) - - if config_entry is None: - connection.send_error( - msg["id"], websocket_api.ERR_NOT_FOUND, "Webhook ID not found" - ) - return - - if config_entry.data[CONF_USER_ID] != connection.user.id: - connection.send_error( - msg["id"], - websocket_api.ERR_UNAUTHORIZED, - "User not linked to this webhook ID", - ) - return - - registered_channels = hass.data[DOMAIN][DATA_PUSH_CHANNEL] - - if webhook_id in registered_channels: - registered_channels.pop(webhook_id) - - @callback - def forward_push_notification(data): - """Forward events to websocket.""" - connection.send_message(websocket_api.messages.event_message(msg["id"], data)) - - @callback - def unsub(): - # pylint: disable=comparison-with-callable - if registered_channels.get(webhook_id) == forward_push_notification: - registered_channels.pop(webhook_id) - - registered_channels[webhook_id] = forward_push_notification - connection.subscriptions[msg["id"]] = unsub - connection.send_result(msg["id"]) diff --git a/homeassistant/components/mobile_app/notify.py b/homeassistant/components/mobile_app/notify.py index c98fdeb9999..025880d8107 100644 --- a/homeassistant/components/mobile_app/notify.py +++ b/homeassistant/components/mobile_app/notify.py @@ -1,5 +1,6 @@ """Support for mobile_app push notifications.""" import asyncio +from functools import partial import logging import aiohttp @@ -124,61 +125,65 @@ class MobileAppNotificationService(BaseNotificationService): for target in targets: if target in local_push_channels: - local_push_channels[target](data) + local_push_channels[target].async_send_notification( + data, partial(self._async_send_remote_message_target, target) + ) continue - entry = self.hass.data[DOMAIN][DATA_CONFIG_ENTRIES][target] - entry_data = entry.data + await self._async_send_remote_message_target(target, data) - app_data = entry_data[ATTR_APP_DATA] - push_token = app_data[ATTR_PUSH_TOKEN] - push_url = app_data[ATTR_PUSH_URL] + async def _async_send_remote_message_target(self, target, data): + """Send a message to a target.""" + entry = self.hass.data[DOMAIN][DATA_CONFIG_ENTRIES][target] + entry_data = entry.data - target_data = dict(data) - target_data[ATTR_PUSH_TOKEN] = push_token + app_data = entry_data[ATTR_APP_DATA] + push_token = app_data[ATTR_PUSH_TOKEN] + push_url = app_data[ATTR_PUSH_URL] - reg_info = { - ATTR_APP_ID: entry_data[ATTR_APP_ID], - ATTR_APP_VERSION: entry_data[ATTR_APP_VERSION], - } - if ATTR_OS_VERSION in entry_data: - reg_info[ATTR_OS_VERSION] = entry_data[ATTR_OS_VERSION] + target_data = dict(data) + target_data[ATTR_PUSH_TOKEN] = push_token - target_data["registration_info"] = reg_info + reg_info = { + ATTR_APP_ID: entry_data[ATTR_APP_ID], + ATTR_APP_VERSION: entry_data[ATTR_APP_VERSION], + } + if ATTR_OS_VERSION in entry_data: + reg_info[ATTR_OS_VERSION] = entry_data[ATTR_OS_VERSION] - try: - with async_timeout.timeout(10): - response = await async_get_clientsession(self._hass).post( - push_url, json=target_data - ) - result = await response.json() + target_data["registration_info"] = reg_info - if response.status in (HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED): - log_rate_limits(self.hass, entry_data[ATTR_DEVICE_NAME], result) - continue - - fallback_error = result.get("errorMessage", "Unknown error") - fallback_message = ( - f"Internal server error, please try again later: {fallback_error}" + try: + with async_timeout.timeout(10): + response = await async_get_clientsession(self._hass).post( + push_url, json=target_data ) - message = result.get("message", fallback_message) + result = await response.json() - if "message" in result: - if message[-1] not in [".", "?", "!"]: - message += "." - message += ( - " This message is generated externally to Home Assistant." - ) + if response.status in (HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED): + log_rate_limits(self.hass, entry_data[ATTR_DEVICE_NAME], result) + return - if response.status == HTTP_TOO_MANY_REQUESTS: - _LOGGER.warning(message) - log_rate_limits( - self.hass, entry_data[ATTR_DEVICE_NAME], result, logging.WARNING - ) - else: - _LOGGER.error(message) + fallback_error = result.get("errorMessage", "Unknown error") + fallback_message = ( + f"Internal server error, please try again later: {fallback_error}" + ) + message = result.get("message", fallback_message) - except asyncio.TimeoutError: - _LOGGER.error("Timeout sending notification to %s", push_url) - except aiohttp.ClientError as err: - _LOGGER.error("Error sending notification to %s: %r", push_url, err) + if "message" in result: + if message[-1] not in [".", "?", "!"]: + message += "." + message += " This message is generated externally to Home Assistant." + + if response.status == HTTP_TOO_MANY_REQUESTS: + _LOGGER.warning(message) + log_rate_limits( + self.hass, entry_data[ATTR_DEVICE_NAME], result, logging.WARNING + ) + else: + _LOGGER.error(message) + + except asyncio.TimeoutError: + _LOGGER.error("Timeout sending notification to %s", push_url) + except aiohttp.ClientError as err: + _LOGGER.error("Error sending notification to %s: %r", push_url, err) diff --git a/homeassistant/components/mobile_app/push_notification.py b/homeassistant/components/mobile_app/push_notification.py new file mode 100644 index 00000000000..1cc5bac5d1c --- /dev/null +++ b/homeassistant/components/mobile_app/push_notification.py @@ -0,0 +1,90 @@ +"""Push notification handling.""" +import asyncio +from typing import Callable + +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.event import async_call_later +from homeassistant.util.uuid import random_uuid_hex + +PUSH_CONFIRM_TIMEOUT = 10 # seconds + + +class PushChannel: + """Class that represents a push channel.""" + + def __init__( + self, + hass: HomeAssistant, + webhook_id: str, + support_confirm: bool, + send_message: Callable[[dict], None], + on_teardown: Callable[[], None], + ) -> None: + """Initialize a local push channel.""" + self.hass = hass + self.webhook_id = webhook_id + self.support_confirm = support_confirm + self._send_message = send_message + self.on_teardown = on_teardown + self.pending_confirms = {} + + @callback + def async_send_notification(self, data, fallback_send): + """Send a push notification.""" + if not self.support_confirm: + self._send_message(data) + return + + confirm_id = random_uuid_hex() + data["hass_confirm_id"] = confirm_id + + async def handle_push_failed(_=None): + """Handle a failed local push notification.""" + # Remove this handler from the pending dict + # If it didn't exist we hit a race condition between call_later and another + # push failing and tearing down the connection. + if self.pending_confirms.pop(confirm_id, None) is None: + return + + # Drop local channel if it's still open + if self.on_teardown is not None: + await self.async_teardown() + + await fallback_send(data) + + self.pending_confirms[confirm_id] = { + "unsub_scheduled_push_failed": async_call_later( + self.hass, PUSH_CONFIRM_TIMEOUT, handle_push_failed + ), + "handle_push_failed": handle_push_failed, + } + self._send_message(data) + + @callback + def async_confirm_notification(self, confirm_id) -> bool: + """Confirm a push notification. + + Returns if confirmation successful. + """ + if confirm_id not in self.pending_confirms: + return False + + self.pending_confirms.pop(confirm_id)["unsub_scheduled_push_failed"]() + return True + + async def async_teardown(self): + """Tear down this channel.""" + # Tear down is in progress + if self.on_teardown is None: + return + + self.on_teardown() + self.on_teardown = None + + cancel_pending_local_tasks = [ + actions["handle_push_failed"]() + for actions in self.pending_confirms.values() + ] + + if cancel_pending_local_tasks: + await asyncio.gather(*cancel_pending_local_tasks) diff --git a/homeassistant/components/mobile_app/websocket_api.py b/homeassistant/components/mobile_app/websocket_api.py new file mode 100644 index 00000000000..4b0863d77af --- /dev/null +++ b/homeassistant/components/mobile_app/websocket_api.py @@ -0,0 +1,121 @@ +"""Mobile app websocket API.""" +from __future__ import annotations + +from functools import wraps + +import voluptuous as vol + +from homeassistant.components import websocket_api +from homeassistant.core import callback + +from .const import CONF_USER_ID, DATA_CONFIG_ENTRIES, DATA_PUSH_CHANNEL, DOMAIN +from .push_notification import PushChannel + + +@callback +def async_setup_commands(hass): + """Set up the mobile app websocket API.""" + websocket_api.async_register_command(hass, handle_push_notification_channel) + websocket_api.async_register_command(hass, handle_push_notification_confirm) + + +def _ensure_webhook_access(func): + """Decorate WS function to ensure user owns the webhook ID.""" + + @callback + @wraps(func) + def with_webhook_access(hass, connection, msg): + # Validate that the webhook ID is registered to the user of the websocket connection + config_entry = hass.data[DOMAIN][DATA_CONFIG_ENTRIES].get(msg["webhook_id"]) + + if config_entry is None: + connection.send_error( + msg["id"], websocket_api.ERR_NOT_FOUND, "Webhook ID not found" + ) + return + + if config_entry.data[CONF_USER_ID] != connection.user.id: + connection.send_error( + msg["id"], + websocket_api.ERR_UNAUTHORIZED, + "User not linked to this webhook ID", + ) + return + + func(hass, connection, msg) + + return with_webhook_access + + +@callback +@_ensure_webhook_access +@websocket_api.websocket_command( + { + vol.Required("type"): "mobile_app/push_notification_confirm", + vol.Required("webhook_id"): str, + vol.Required("confirm_id"): str, + } +) +def handle_push_notification_confirm(hass, connection, msg): + """Confirm receipt of a push notification.""" + channel: PushChannel | None = hass.data[DOMAIN][DATA_PUSH_CHANNEL].get( + msg["webhook_id"] + ) + if channel is None: + connection.send_error( + msg["id"], + websocket_api.ERR_NOT_FOUND, + "Push notification channel not found", + ) + return + + if channel.async_confirm_notification(msg["confirm_id"]): + connection.send_result(msg["id"]) + else: + connection.send_error( + msg["id"], + websocket_api.ERR_NOT_FOUND, + "Push notification channel not found", + ) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "mobile_app/push_notification_channel", + vol.Required("webhook_id"): str, + vol.Optional("support_confirm", default=False): bool, + } +) +@_ensure_webhook_access +@websocket_api.async_response +async def handle_push_notification_channel(hass, connection, msg): + """Set up a direct push notification channel.""" + webhook_id = msg["webhook_id"] + registered_channels: dict[str, PushChannel] = hass.data[DOMAIN][DATA_PUSH_CHANNEL] + + if webhook_id in registered_channels: + await registered_channels[webhook_id].async_teardown() + + @callback + def on_channel_teardown(): + """Handle teardown.""" + if registered_channels.get(webhook_id) == channel: + registered_channels.pop(webhook_id) + + # Remove subscription from connection if still exists + connection.subscriptions.pop(msg["id"], None) + + channel = registered_channels[webhook_id] = PushChannel( + hass, + webhook_id, + msg["support_confirm"], + lambda data: connection.send_message( + websocket_api.messages.event_message(msg["id"], data) + ), + on_channel_teardown, + ) + + connection.subscriptions[msg["id"]] = lambda: hass.async_create_task( + channel.async_teardown() + ) + connection.send_result(msg["id"]) diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 62c21ef5894..0d3bd5fdf4d 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -104,8 +104,8 @@ class ActiveConnection: self.last_id = cur_id @callback - def async_close(self) -> None: - """Close down connection.""" + def async_handle_close(self) -> None: + """Handle closing down connection.""" for unsub in self.subscriptions.values(): unsub() diff --git a/homeassistant/components/websocket_api/http.py b/homeassistant/components/websocket_api/http.py index d51eff7459e..aa6a74b27ec 100644 --- a/homeassistant/components/websocket_api/http.py +++ b/homeassistant/components/websocket_api/http.py @@ -231,7 +231,7 @@ class WebSocketHandler: unsub_stop() if connection is not None: - connection.async_close() + connection.async_handle_close() try: self._to_write.put_nowait(None) diff --git a/tests/components/mobile_app/test_notify.py b/tests/components/mobile_app/test_notify.py index 1e3b999d5f5..c0e1b4c2a85 100644 --- a/tests/components/mobile_app/test_notify.py +++ b/tests/components/mobile_app/test_notify.py @@ -1,5 +1,6 @@ """Notify platform tests for mobile_app.""" from datetime import datetime, timedelta +from unittest.mock import patch import pytest @@ -204,3 +205,130 @@ async def test_notify_ws_works( "code": "unauthorized", "message": "User not linked to this webhook ID", } + + +async def test_notify_ws_confirming_works( + hass, aioclient_mock, setup_push_receiver, hass_ws_client +): + """Test notify confirming works.""" + client = await hass_ws_client(hass) + + await client.send_json( + { + "id": 5, + "type": "mobile_app/push_notification_channel", + "webhook_id": "mock-webhook_id", + "support_confirm": True, + } + ) + + sub_result = await client.receive_json() + assert sub_result["success"] + + # Sent a message that will be delivered locally + assert await hass.services.async_call( + "notify", "mobile_app_test", {"message": "Hello world"}, blocking=True + ) + + msg_result = await client.receive_json() + confirm_id = msg_result["event"].pop("hass_confirm_id") + assert confirm_id is not None + assert msg_result["event"] == {"message": "Hello world"} + + # Try to confirm with incorrect confirm ID + await client.send_json( + { + "id": 6, + "type": "mobile_app/push_notification_confirm", + "webhook_id": "mock-webhook_id", + "confirm_id": "incorrect-confirm-id", + } + ) + + result = await client.receive_json() + assert not result["success"] + assert result["error"] == { + "code": "not_found", + "message": "Push notification channel not found", + } + + # Confirm with correct confirm ID + await client.send_json( + { + "id": 7, + "type": "mobile_app/push_notification_confirm", + "webhook_id": "mock-webhook_id", + "confirm_id": confirm_id, + } + ) + + result = await client.receive_json() + assert result["success"] + + # Drop local push channel and try to confirm another message + await client.send_json( + { + "id": 8, + "type": "unsubscribe_events", + "subscription": 5, + } + ) + sub_result = await client.receive_json() + assert sub_result["success"] + + await client.send_json( + { + "id": 9, + "type": "mobile_app/push_notification_confirm", + "webhook_id": "mock-webhook_id", + "confirm_id": confirm_id, + } + ) + + result = await client.receive_json() + assert not result["success"] + assert result["error"] == { + "code": "not_found", + "message": "Push notification channel not found", + } + + +async def test_notify_ws_not_confirming( + hass, aioclient_mock, setup_push_receiver, hass_ws_client +): + """Test we go via cloud when failed to confirm.""" + client = await hass_ws_client(hass) + + await client.send_json( + { + "id": 5, + "type": "mobile_app/push_notification_channel", + "webhook_id": "mock-webhook_id", + "support_confirm": True, + } + ) + + sub_result = await client.receive_json() + assert sub_result["success"] + + assert await hass.services.async_call( + "notify", "mobile_app_test", {"message": "Hello world 1"}, blocking=True + ) + + with patch( + "homeassistant.components.mobile_app.push_notification.PUSH_CONFIRM_TIMEOUT", 0 + ): + assert await hass.services.async_call( + "notify", "mobile_app_test", {"message": "Hello world 2"}, blocking=True + ) + await hass.async_block_till_done() + + # When we fail, all unconfirmed ones and failed one are sent via cloud + assert len(aioclient_mock.mock_calls) == 2 + + # All future ones also go via cloud + assert await hass.services.async_call( + "notify", "mobile_app_test", {"message": "Hello world 3"}, blocking=True + ) + + assert len(aioclient_mock.mock_calls) == 3