Allow confirming local push notifications (#54947)

* Allow confirming local push notifications

* Fix from Zac

* Add tests
pull/56546/head
Paulus Schoutsen 2021-09-22 14:17:04 -07:00 committed by GitHub
parent f77e93ceeb
commit 677abcd484
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 397 additions and 104 deletions

View File

@ -1,25 +1,23 @@
"""Integrates Native Apps to Home Assistant."""
from contextlib import suppress
import voluptuous as vol
from homeassistant.components import cloud, notify as hass_notify, websocket_api
from homeassistant.components import cloud, notify as hass_notify
from homeassistant.components.webhook import (
async_register as webhook_register,
async_unregister as webhook_unregister,
)
from homeassistant.const import ATTR_DEVICE_ID, CONF_WEBHOOK_ID
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, discovery
from homeassistant.helpers.typing import ConfigType
from . import websocket_api
from .const import (
ATTR_DEVICE_NAME,
ATTR_MANUFACTURER,
ATTR_MODEL,
ATTR_OS_VERSION,
CONF_CLOUDHOOK_URL,
CONF_USER_ID,
DATA_CONFIG_ENTRIES,
DATA_DELETED_IDS,
DATA_DEVICES,
@ -66,7 +64,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
discovery.async_load_platform(hass, "notify", DOMAIN, {}, config)
)
websocket_api.async_register_command(hass, handle_push_notification_channel)
websocket_api.async_setup_commands(hass)
return True
@ -127,52 +125,3 @@ async def async_remove_entry(hass, entry):
if CONF_CLOUDHOOK_URL in entry.data:
with suppress(cloud.CloudNotAvailable):
await cloud.async_delete_cloudhook(hass, entry.data[CONF_WEBHOOK_ID])
@callback
@websocket_api.websocket_command(
{
vol.Required("type"): "mobile_app/push_notification_channel",
vol.Required("webhook_id"): str,
}
)
def handle_push_notification_channel(hass, connection, msg):
"""Set up a direct push notification channel."""
webhook_id = msg["webhook_id"]
# Validate that the webhook ID is registered to the user of the websocket connection
config_entry = hass.data[DOMAIN][DATA_CONFIG_ENTRIES].get(webhook_id)
if config_entry is None:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "Webhook ID not found"
)
return
if config_entry.data[CONF_USER_ID] != connection.user.id:
connection.send_error(
msg["id"],
websocket_api.ERR_UNAUTHORIZED,
"User not linked to this webhook ID",
)
return
registered_channels = hass.data[DOMAIN][DATA_PUSH_CHANNEL]
if webhook_id in registered_channels:
registered_channels.pop(webhook_id)
@callback
def forward_push_notification(data):
"""Forward events to websocket."""
connection.send_message(websocket_api.messages.event_message(msg["id"], data))
@callback
def unsub():
# pylint: disable=comparison-with-callable
if registered_channels.get(webhook_id) == forward_push_notification:
registered_channels.pop(webhook_id)
registered_channels[webhook_id] = forward_push_notification
connection.subscriptions[msg["id"]] = unsub
connection.send_result(msg["id"])

View File

@ -1,5 +1,6 @@
"""Support for mobile_app push notifications."""
import asyncio
from functools import partial
import logging
import aiohttp
@ -124,61 +125,65 @@ class MobileAppNotificationService(BaseNotificationService):
for target in targets:
if target in local_push_channels:
local_push_channels[target](data)
local_push_channels[target].async_send_notification(
data, partial(self._async_send_remote_message_target, target)
)
continue
entry = self.hass.data[DOMAIN][DATA_CONFIG_ENTRIES][target]
entry_data = entry.data
await self._async_send_remote_message_target(target, data)
app_data = entry_data[ATTR_APP_DATA]
push_token = app_data[ATTR_PUSH_TOKEN]
push_url = app_data[ATTR_PUSH_URL]
async def _async_send_remote_message_target(self, target, data):
"""Send a message to a target."""
entry = self.hass.data[DOMAIN][DATA_CONFIG_ENTRIES][target]
entry_data = entry.data
target_data = dict(data)
target_data[ATTR_PUSH_TOKEN] = push_token
app_data = entry_data[ATTR_APP_DATA]
push_token = app_data[ATTR_PUSH_TOKEN]
push_url = app_data[ATTR_PUSH_URL]
reg_info = {
ATTR_APP_ID: entry_data[ATTR_APP_ID],
ATTR_APP_VERSION: entry_data[ATTR_APP_VERSION],
}
if ATTR_OS_VERSION in entry_data:
reg_info[ATTR_OS_VERSION] = entry_data[ATTR_OS_VERSION]
target_data = dict(data)
target_data[ATTR_PUSH_TOKEN] = push_token
target_data["registration_info"] = reg_info
reg_info = {
ATTR_APP_ID: entry_data[ATTR_APP_ID],
ATTR_APP_VERSION: entry_data[ATTR_APP_VERSION],
}
if ATTR_OS_VERSION in entry_data:
reg_info[ATTR_OS_VERSION] = entry_data[ATTR_OS_VERSION]
try:
with async_timeout.timeout(10):
response = await async_get_clientsession(self._hass).post(
push_url, json=target_data
)
result = await response.json()
target_data["registration_info"] = reg_info
if response.status in (HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED):
log_rate_limits(self.hass, entry_data[ATTR_DEVICE_NAME], result)
continue
fallback_error = result.get("errorMessage", "Unknown error")
fallback_message = (
f"Internal server error, please try again later: {fallback_error}"
try:
with async_timeout.timeout(10):
response = await async_get_clientsession(self._hass).post(
push_url, json=target_data
)
message = result.get("message", fallback_message)
result = await response.json()
if "message" in result:
if message[-1] not in [".", "?", "!"]:
message += "."
message += (
" This message is generated externally to Home Assistant."
)
if response.status in (HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED):
log_rate_limits(self.hass, entry_data[ATTR_DEVICE_NAME], result)
return
if response.status == HTTP_TOO_MANY_REQUESTS:
_LOGGER.warning(message)
log_rate_limits(
self.hass, entry_data[ATTR_DEVICE_NAME], result, logging.WARNING
)
else:
_LOGGER.error(message)
fallback_error = result.get("errorMessage", "Unknown error")
fallback_message = (
f"Internal server error, please try again later: {fallback_error}"
)
message = result.get("message", fallback_message)
except asyncio.TimeoutError:
_LOGGER.error("Timeout sending notification to %s", push_url)
except aiohttp.ClientError as err:
_LOGGER.error("Error sending notification to %s: %r", push_url, err)
if "message" in result:
if message[-1] not in [".", "?", "!"]:
message += "."
message += " This message is generated externally to Home Assistant."
if response.status == HTTP_TOO_MANY_REQUESTS:
_LOGGER.warning(message)
log_rate_limits(
self.hass, entry_data[ATTR_DEVICE_NAME], result, logging.WARNING
)
else:
_LOGGER.error(message)
except asyncio.TimeoutError:
_LOGGER.error("Timeout sending notification to %s", push_url)
except aiohttp.ClientError as err:
_LOGGER.error("Error sending notification to %s: %r", push_url, err)

View File

@ -0,0 +1,90 @@
"""Push notification handling."""
import asyncio
from typing import Callable
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.event import async_call_later
from homeassistant.util.uuid import random_uuid_hex
PUSH_CONFIRM_TIMEOUT = 10 # seconds
class PushChannel:
"""Class that represents a push channel."""
def __init__(
self,
hass: HomeAssistant,
webhook_id: str,
support_confirm: bool,
send_message: Callable[[dict], None],
on_teardown: Callable[[], None],
) -> None:
"""Initialize a local push channel."""
self.hass = hass
self.webhook_id = webhook_id
self.support_confirm = support_confirm
self._send_message = send_message
self.on_teardown = on_teardown
self.pending_confirms = {}
@callback
def async_send_notification(self, data, fallback_send):
"""Send a push notification."""
if not self.support_confirm:
self._send_message(data)
return
confirm_id = random_uuid_hex()
data["hass_confirm_id"] = confirm_id
async def handle_push_failed(_=None):
"""Handle a failed local push notification."""
# Remove this handler from the pending dict
# If it didn't exist we hit a race condition between call_later and another
# push failing and tearing down the connection.
if self.pending_confirms.pop(confirm_id, None) is None:
return
# Drop local channel if it's still open
if self.on_teardown is not None:
await self.async_teardown()
await fallback_send(data)
self.pending_confirms[confirm_id] = {
"unsub_scheduled_push_failed": async_call_later(
self.hass, PUSH_CONFIRM_TIMEOUT, handle_push_failed
),
"handle_push_failed": handle_push_failed,
}
self._send_message(data)
@callback
def async_confirm_notification(self, confirm_id) -> bool:
"""Confirm a push notification.
Returns if confirmation successful.
"""
if confirm_id not in self.pending_confirms:
return False
self.pending_confirms.pop(confirm_id)["unsub_scheduled_push_failed"]()
return True
async def async_teardown(self):
"""Tear down this channel."""
# Tear down is in progress
if self.on_teardown is None:
return
self.on_teardown()
self.on_teardown = None
cancel_pending_local_tasks = [
actions["handle_push_failed"]()
for actions in self.pending_confirms.values()
]
if cancel_pending_local_tasks:
await asyncio.gather(*cancel_pending_local_tasks)

View File

@ -0,0 +1,121 @@
"""Mobile app websocket API."""
from __future__ import annotations
from functools import wraps
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import callback
from .const import CONF_USER_ID, DATA_CONFIG_ENTRIES, DATA_PUSH_CHANNEL, DOMAIN
from .push_notification import PushChannel
@callback
def async_setup_commands(hass):
"""Set up the mobile app websocket API."""
websocket_api.async_register_command(hass, handle_push_notification_channel)
websocket_api.async_register_command(hass, handle_push_notification_confirm)
def _ensure_webhook_access(func):
"""Decorate WS function to ensure user owns the webhook ID."""
@callback
@wraps(func)
def with_webhook_access(hass, connection, msg):
# Validate that the webhook ID is registered to the user of the websocket connection
config_entry = hass.data[DOMAIN][DATA_CONFIG_ENTRIES].get(msg["webhook_id"])
if config_entry is None:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "Webhook ID not found"
)
return
if config_entry.data[CONF_USER_ID] != connection.user.id:
connection.send_error(
msg["id"],
websocket_api.ERR_UNAUTHORIZED,
"User not linked to this webhook ID",
)
return
func(hass, connection, msg)
return with_webhook_access
@callback
@_ensure_webhook_access
@websocket_api.websocket_command(
{
vol.Required("type"): "mobile_app/push_notification_confirm",
vol.Required("webhook_id"): str,
vol.Required("confirm_id"): str,
}
)
def handle_push_notification_confirm(hass, connection, msg):
"""Confirm receipt of a push notification."""
channel: PushChannel | None = hass.data[DOMAIN][DATA_PUSH_CHANNEL].get(
msg["webhook_id"]
)
if channel is None:
connection.send_error(
msg["id"],
websocket_api.ERR_NOT_FOUND,
"Push notification channel not found",
)
return
if channel.async_confirm_notification(msg["confirm_id"]):
connection.send_result(msg["id"])
else:
connection.send_error(
msg["id"],
websocket_api.ERR_NOT_FOUND,
"Push notification channel not found",
)
@websocket_api.websocket_command(
{
vol.Required("type"): "mobile_app/push_notification_channel",
vol.Required("webhook_id"): str,
vol.Optional("support_confirm", default=False): bool,
}
)
@_ensure_webhook_access
@websocket_api.async_response
async def handle_push_notification_channel(hass, connection, msg):
"""Set up a direct push notification channel."""
webhook_id = msg["webhook_id"]
registered_channels: dict[str, PushChannel] = hass.data[DOMAIN][DATA_PUSH_CHANNEL]
if webhook_id in registered_channels:
await registered_channels[webhook_id].async_teardown()
@callback
def on_channel_teardown():
"""Handle teardown."""
if registered_channels.get(webhook_id) == channel:
registered_channels.pop(webhook_id)
# Remove subscription from connection if still exists
connection.subscriptions.pop(msg["id"], None)
channel = registered_channels[webhook_id] = PushChannel(
hass,
webhook_id,
msg["support_confirm"],
lambda data: connection.send_message(
websocket_api.messages.event_message(msg["id"], data)
),
on_channel_teardown,
)
connection.subscriptions[msg["id"]] = lambda: hass.async_create_task(
channel.async_teardown()
)
connection.send_result(msg["id"])

View File

@ -104,8 +104,8 @@ class ActiveConnection:
self.last_id = cur_id
@callback
def async_close(self) -> None:
"""Close down connection."""
def async_handle_close(self) -> None:
"""Handle closing down connection."""
for unsub in self.subscriptions.values():
unsub()

View File

@ -231,7 +231,7 @@ class WebSocketHandler:
unsub_stop()
if connection is not None:
connection.async_close()
connection.async_handle_close()
try:
self._to_write.put_nowait(None)

View File

@ -1,5 +1,6 @@
"""Notify platform tests for mobile_app."""
from datetime import datetime, timedelta
from unittest.mock import patch
import pytest
@ -204,3 +205,130 @@ async def test_notify_ws_works(
"code": "unauthorized",
"message": "User not linked to this webhook ID",
}
async def test_notify_ws_confirming_works(
hass, aioclient_mock, setup_push_receiver, hass_ws_client
):
"""Test notify confirming works."""
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "mobile_app/push_notification_channel",
"webhook_id": "mock-webhook_id",
"support_confirm": True,
}
)
sub_result = await client.receive_json()
assert sub_result["success"]
# Sent a message that will be delivered locally
assert await hass.services.async_call(
"notify", "mobile_app_test", {"message": "Hello world"}, blocking=True
)
msg_result = await client.receive_json()
confirm_id = msg_result["event"].pop("hass_confirm_id")
assert confirm_id is not None
assert msg_result["event"] == {"message": "Hello world"}
# Try to confirm with incorrect confirm ID
await client.send_json(
{
"id": 6,
"type": "mobile_app/push_notification_confirm",
"webhook_id": "mock-webhook_id",
"confirm_id": "incorrect-confirm-id",
}
)
result = await client.receive_json()
assert not result["success"]
assert result["error"] == {
"code": "not_found",
"message": "Push notification channel not found",
}
# Confirm with correct confirm ID
await client.send_json(
{
"id": 7,
"type": "mobile_app/push_notification_confirm",
"webhook_id": "mock-webhook_id",
"confirm_id": confirm_id,
}
)
result = await client.receive_json()
assert result["success"]
# Drop local push channel and try to confirm another message
await client.send_json(
{
"id": 8,
"type": "unsubscribe_events",
"subscription": 5,
}
)
sub_result = await client.receive_json()
assert sub_result["success"]
await client.send_json(
{
"id": 9,
"type": "mobile_app/push_notification_confirm",
"webhook_id": "mock-webhook_id",
"confirm_id": confirm_id,
}
)
result = await client.receive_json()
assert not result["success"]
assert result["error"] == {
"code": "not_found",
"message": "Push notification channel not found",
}
async def test_notify_ws_not_confirming(
hass, aioclient_mock, setup_push_receiver, hass_ws_client
):
"""Test we go via cloud when failed to confirm."""
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 5,
"type": "mobile_app/push_notification_channel",
"webhook_id": "mock-webhook_id",
"support_confirm": True,
}
)
sub_result = await client.receive_json()
assert sub_result["success"]
assert await hass.services.async_call(
"notify", "mobile_app_test", {"message": "Hello world 1"}, blocking=True
)
with patch(
"homeassistant.components.mobile_app.push_notification.PUSH_CONFIRM_TIMEOUT", 0
):
assert await hass.services.async_call(
"notify", "mobile_app_test", {"message": "Hello world 2"}, blocking=True
)
await hass.async_block_till_done()
# When we fail, all unconfirmed ones and failed one are sent via cloud
assert len(aioclient_mock.mock_calls) == 2
# All future ones also go via cloud
assert await hass.services.async_call(
"notify", "mobile_app_test", {"message": "Hello world 3"}, blocking=True
)
assert len(aioclient_mock.mock_calls) == 3