Allow confirming local push notifications (#54947)
* Allow confirming local push notifications * Fix from Zac * Add testspull/56546/head
parent
f77e93ceeb
commit
677abcd484
|
@ -1,25 +1,23 @@
|
|||
"""Integrates Native Apps to Home Assistant."""
|
||||
from contextlib import suppress
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import cloud, notify as hass_notify, websocket_api
|
||||
from homeassistant.components import cloud, notify as hass_notify
|
||||
from homeassistant.components.webhook import (
|
||||
async_register as webhook_register,
|
||||
async_unregister as webhook_unregister,
|
||||
)
|
||||
from homeassistant.const import ATTR_DEVICE_ID, CONF_WEBHOOK_ID
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr, discovery
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import websocket_api
|
||||
from .const import (
|
||||
ATTR_DEVICE_NAME,
|
||||
ATTR_MANUFACTURER,
|
||||
ATTR_MODEL,
|
||||
ATTR_OS_VERSION,
|
||||
CONF_CLOUDHOOK_URL,
|
||||
CONF_USER_ID,
|
||||
DATA_CONFIG_ENTRIES,
|
||||
DATA_DELETED_IDS,
|
||||
DATA_DEVICES,
|
||||
|
@ -66,7 +64,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
discovery.async_load_platform(hass, "notify", DOMAIN, {}, config)
|
||||
)
|
||||
|
||||
websocket_api.async_register_command(hass, handle_push_notification_channel)
|
||||
websocket_api.async_setup_commands(hass)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -127,52 +125,3 @@ async def async_remove_entry(hass, entry):
|
|||
if CONF_CLOUDHOOK_URL in entry.data:
|
||||
with suppress(cloud.CloudNotAvailable):
|
||||
await cloud.async_delete_cloudhook(hass, entry.data[CONF_WEBHOOK_ID])
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "mobile_app/push_notification_channel",
|
||||
vol.Required("webhook_id"): str,
|
||||
}
|
||||
)
|
||||
def handle_push_notification_channel(hass, connection, msg):
|
||||
"""Set up a direct push notification channel."""
|
||||
webhook_id = msg["webhook_id"]
|
||||
|
||||
# Validate that the webhook ID is registered to the user of the websocket connection
|
||||
config_entry = hass.data[DOMAIN][DATA_CONFIG_ENTRIES].get(webhook_id)
|
||||
|
||||
if config_entry is None:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.ERR_NOT_FOUND, "Webhook ID not found"
|
||||
)
|
||||
return
|
||||
|
||||
if config_entry.data[CONF_USER_ID] != connection.user.id:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.ERR_UNAUTHORIZED,
|
||||
"User not linked to this webhook ID",
|
||||
)
|
||||
return
|
||||
|
||||
registered_channels = hass.data[DOMAIN][DATA_PUSH_CHANNEL]
|
||||
|
||||
if webhook_id in registered_channels:
|
||||
registered_channels.pop(webhook_id)
|
||||
|
||||
@callback
|
||||
def forward_push_notification(data):
|
||||
"""Forward events to websocket."""
|
||||
connection.send_message(websocket_api.messages.event_message(msg["id"], data))
|
||||
|
||||
@callback
|
||||
def unsub():
|
||||
# pylint: disable=comparison-with-callable
|
||||
if registered_channels.get(webhook_id) == forward_push_notification:
|
||||
registered_channels.pop(webhook_id)
|
||||
|
||||
registered_channels[webhook_id] = forward_push_notification
|
||||
connection.subscriptions[msg["id"]] = unsub
|
||||
connection.send_result(msg["id"])
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Support for mobile_app push notifications."""
|
||||
import asyncio
|
||||
from functools import partial
|
||||
import logging
|
||||
|
||||
import aiohttp
|
||||
|
@ -124,61 +125,65 @@ class MobileAppNotificationService(BaseNotificationService):
|
|||
|
||||
for target in targets:
|
||||
if target in local_push_channels:
|
||||
local_push_channels[target](data)
|
||||
local_push_channels[target].async_send_notification(
|
||||
data, partial(self._async_send_remote_message_target, target)
|
||||
)
|
||||
continue
|
||||
|
||||
entry = self.hass.data[DOMAIN][DATA_CONFIG_ENTRIES][target]
|
||||
entry_data = entry.data
|
||||
await self._async_send_remote_message_target(target, data)
|
||||
|
||||
app_data = entry_data[ATTR_APP_DATA]
|
||||
push_token = app_data[ATTR_PUSH_TOKEN]
|
||||
push_url = app_data[ATTR_PUSH_URL]
|
||||
async def _async_send_remote_message_target(self, target, data):
|
||||
"""Send a message to a target."""
|
||||
entry = self.hass.data[DOMAIN][DATA_CONFIG_ENTRIES][target]
|
||||
entry_data = entry.data
|
||||
|
||||
target_data = dict(data)
|
||||
target_data[ATTR_PUSH_TOKEN] = push_token
|
||||
app_data = entry_data[ATTR_APP_DATA]
|
||||
push_token = app_data[ATTR_PUSH_TOKEN]
|
||||
push_url = app_data[ATTR_PUSH_URL]
|
||||
|
||||
reg_info = {
|
||||
ATTR_APP_ID: entry_data[ATTR_APP_ID],
|
||||
ATTR_APP_VERSION: entry_data[ATTR_APP_VERSION],
|
||||
}
|
||||
if ATTR_OS_VERSION in entry_data:
|
||||
reg_info[ATTR_OS_VERSION] = entry_data[ATTR_OS_VERSION]
|
||||
target_data = dict(data)
|
||||
target_data[ATTR_PUSH_TOKEN] = push_token
|
||||
|
||||
target_data["registration_info"] = reg_info
|
||||
reg_info = {
|
||||
ATTR_APP_ID: entry_data[ATTR_APP_ID],
|
||||
ATTR_APP_VERSION: entry_data[ATTR_APP_VERSION],
|
||||
}
|
||||
if ATTR_OS_VERSION in entry_data:
|
||||
reg_info[ATTR_OS_VERSION] = entry_data[ATTR_OS_VERSION]
|
||||
|
||||
try:
|
||||
with async_timeout.timeout(10):
|
||||
response = await async_get_clientsession(self._hass).post(
|
||||
push_url, json=target_data
|
||||
)
|
||||
result = await response.json()
|
||||
target_data["registration_info"] = reg_info
|
||||
|
||||
if response.status in (HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED):
|
||||
log_rate_limits(self.hass, entry_data[ATTR_DEVICE_NAME], result)
|
||||
continue
|
||||
|
||||
fallback_error = result.get("errorMessage", "Unknown error")
|
||||
fallback_message = (
|
||||
f"Internal server error, please try again later: {fallback_error}"
|
||||
try:
|
||||
with async_timeout.timeout(10):
|
||||
response = await async_get_clientsession(self._hass).post(
|
||||
push_url, json=target_data
|
||||
)
|
||||
message = result.get("message", fallback_message)
|
||||
result = await response.json()
|
||||
|
||||
if "message" in result:
|
||||
if message[-1] not in [".", "?", "!"]:
|
||||
message += "."
|
||||
message += (
|
||||
" This message is generated externally to Home Assistant."
|
||||
)
|
||||
if response.status in (HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED):
|
||||
log_rate_limits(self.hass, entry_data[ATTR_DEVICE_NAME], result)
|
||||
return
|
||||
|
||||
if response.status == HTTP_TOO_MANY_REQUESTS:
|
||||
_LOGGER.warning(message)
|
||||
log_rate_limits(
|
||||
self.hass, entry_data[ATTR_DEVICE_NAME], result, logging.WARNING
|
||||
)
|
||||
else:
|
||||
_LOGGER.error(message)
|
||||
fallback_error = result.get("errorMessage", "Unknown error")
|
||||
fallback_message = (
|
||||
f"Internal server error, please try again later: {fallback_error}"
|
||||
)
|
||||
message = result.get("message", fallback_message)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
_LOGGER.error("Timeout sending notification to %s", push_url)
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.error("Error sending notification to %s: %r", push_url, err)
|
||||
if "message" in result:
|
||||
if message[-1] not in [".", "?", "!"]:
|
||||
message += "."
|
||||
message += " This message is generated externally to Home Assistant."
|
||||
|
||||
if response.status == HTTP_TOO_MANY_REQUESTS:
|
||||
_LOGGER.warning(message)
|
||||
log_rate_limits(
|
||||
self.hass, entry_data[ATTR_DEVICE_NAME], result, logging.WARNING
|
||||
)
|
||||
else:
|
||||
_LOGGER.error(message)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
_LOGGER.error("Timeout sending notification to %s", push_url)
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.error("Error sending notification to %s: %r", push_url, err)
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
"""Push notification handling."""
|
||||
import asyncio
|
||||
from typing import Callable
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.util.uuid import random_uuid_hex
|
||||
|
||||
PUSH_CONFIRM_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
class PushChannel:
|
||||
"""Class that represents a push channel."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
webhook_id: str,
|
||||
support_confirm: bool,
|
||||
send_message: Callable[[dict], None],
|
||||
on_teardown: Callable[[], None],
|
||||
) -> None:
|
||||
"""Initialize a local push channel."""
|
||||
self.hass = hass
|
||||
self.webhook_id = webhook_id
|
||||
self.support_confirm = support_confirm
|
||||
self._send_message = send_message
|
||||
self.on_teardown = on_teardown
|
||||
self.pending_confirms = {}
|
||||
|
||||
@callback
|
||||
def async_send_notification(self, data, fallback_send):
|
||||
"""Send a push notification."""
|
||||
if not self.support_confirm:
|
||||
self._send_message(data)
|
||||
return
|
||||
|
||||
confirm_id = random_uuid_hex()
|
||||
data["hass_confirm_id"] = confirm_id
|
||||
|
||||
async def handle_push_failed(_=None):
|
||||
"""Handle a failed local push notification."""
|
||||
# Remove this handler from the pending dict
|
||||
# If it didn't exist we hit a race condition between call_later and another
|
||||
# push failing and tearing down the connection.
|
||||
if self.pending_confirms.pop(confirm_id, None) is None:
|
||||
return
|
||||
|
||||
# Drop local channel if it's still open
|
||||
if self.on_teardown is not None:
|
||||
await self.async_teardown()
|
||||
|
||||
await fallback_send(data)
|
||||
|
||||
self.pending_confirms[confirm_id] = {
|
||||
"unsub_scheduled_push_failed": async_call_later(
|
||||
self.hass, PUSH_CONFIRM_TIMEOUT, handle_push_failed
|
||||
),
|
||||
"handle_push_failed": handle_push_failed,
|
||||
}
|
||||
self._send_message(data)
|
||||
|
||||
@callback
|
||||
def async_confirm_notification(self, confirm_id) -> bool:
|
||||
"""Confirm a push notification.
|
||||
|
||||
Returns if confirmation successful.
|
||||
"""
|
||||
if confirm_id not in self.pending_confirms:
|
||||
return False
|
||||
|
||||
self.pending_confirms.pop(confirm_id)["unsub_scheduled_push_failed"]()
|
||||
return True
|
||||
|
||||
async def async_teardown(self):
|
||||
"""Tear down this channel."""
|
||||
# Tear down is in progress
|
||||
if self.on_teardown is None:
|
||||
return
|
||||
|
||||
self.on_teardown()
|
||||
self.on_teardown = None
|
||||
|
||||
cancel_pending_local_tasks = [
|
||||
actions["handle_push_failed"]()
|
||||
for actions in self.pending_confirms.values()
|
||||
]
|
||||
|
||||
if cancel_pending_local_tasks:
|
||||
await asyncio.gather(*cancel_pending_local_tasks)
|
|
@ -0,0 +1,121 @@
|
|||
"""Mobile app websocket API."""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import callback
|
||||
|
||||
from .const import CONF_USER_ID, DATA_CONFIG_ENTRIES, DATA_PUSH_CHANNEL, DOMAIN
|
||||
from .push_notification import PushChannel
|
||||
|
||||
|
||||
@callback
|
||||
def async_setup_commands(hass):
|
||||
"""Set up the mobile app websocket API."""
|
||||
websocket_api.async_register_command(hass, handle_push_notification_channel)
|
||||
websocket_api.async_register_command(hass, handle_push_notification_confirm)
|
||||
|
||||
|
||||
def _ensure_webhook_access(func):
|
||||
"""Decorate WS function to ensure user owns the webhook ID."""
|
||||
|
||||
@callback
|
||||
@wraps(func)
|
||||
def with_webhook_access(hass, connection, msg):
|
||||
# Validate that the webhook ID is registered to the user of the websocket connection
|
||||
config_entry = hass.data[DOMAIN][DATA_CONFIG_ENTRIES].get(msg["webhook_id"])
|
||||
|
||||
if config_entry is None:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.ERR_NOT_FOUND, "Webhook ID not found"
|
||||
)
|
||||
return
|
||||
|
||||
if config_entry.data[CONF_USER_ID] != connection.user.id:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.ERR_UNAUTHORIZED,
|
||||
"User not linked to this webhook ID",
|
||||
)
|
||||
return
|
||||
|
||||
func(hass, connection, msg)
|
||||
|
||||
return with_webhook_access
|
||||
|
||||
|
||||
@callback
|
||||
@_ensure_webhook_access
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "mobile_app/push_notification_confirm",
|
||||
vol.Required("webhook_id"): str,
|
||||
vol.Required("confirm_id"): str,
|
||||
}
|
||||
)
|
||||
def handle_push_notification_confirm(hass, connection, msg):
|
||||
"""Confirm receipt of a push notification."""
|
||||
channel: PushChannel | None = hass.data[DOMAIN][DATA_PUSH_CHANNEL].get(
|
||||
msg["webhook_id"]
|
||||
)
|
||||
if channel is None:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.ERR_NOT_FOUND,
|
||||
"Push notification channel not found",
|
||||
)
|
||||
return
|
||||
|
||||
if channel.async_confirm_notification(msg["confirm_id"]):
|
||||
connection.send_result(msg["id"])
|
||||
else:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.ERR_NOT_FOUND,
|
||||
"Push notification channel not found",
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "mobile_app/push_notification_channel",
|
||||
vol.Required("webhook_id"): str,
|
||||
vol.Optional("support_confirm", default=False): bool,
|
||||
}
|
||||
)
|
||||
@_ensure_webhook_access
|
||||
@websocket_api.async_response
|
||||
async def handle_push_notification_channel(hass, connection, msg):
|
||||
"""Set up a direct push notification channel."""
|
||||
webhook_id = msg["webhook_id"]
|
||||
registered_channels: dict[str, PushChannel] = hass.data[DOMAIN][DATA_PUSH_CHANNEL]
|
||||
|
||||
if webhook_id in registered_channels:
|
||||
await registered_channels[webhook_id].async_teardown()
|
||||
|
||||
@callback
|
||||
def on_channel_teardown():
|
||||
"""Handle teardown."""
|
||||
if registered_channels.get(webhook_id) == channel:
|
||||
registered_channels.pop(webhook_id)
|
||||
|
||||
# Remove subscription from connection if still exists
|
||||
connection.subscriptions.pop(msg["id"], None)
|
||||
|
||||
channel = registered_channels[webhook_id] = PushChannel(
|
||||
hass,
|
||||
webhook_id,
|
||||
msg["support_confirm"],
|
||||
lambda data: connection.send_message(
|
||||
websocket_api.messages.event_message(msg["id"], data)
|
||||
),
|
||||
on_channel_teardown,
|
||||
)
|
||||
|
||||
connection.subscriptions[msg["id"]] = lambda: hass.async_create_task(
|
||||
channel.async_teardown()
|
||||
)
|
||||
connection.send_result(msg["id"])
|
|
@ -104,8 +104,8 @@ class ActiveConnection:
|
|||
self.last_id = cur_id
|
||||
|
||||
@callback
|
||||
def async_close(self) -> None:
|
||||
"""Close down connection."""
|
||||
def async_handle_close(self) -> None:
|
||||
"""Handle closing down connection."""
|
||||
for unsub in self.subscriptions.values():
|
||||
unsub()
|
||||
|
||||
|
|
|
@ -231,7 +231,7 @@ class WebSocketHandler:
|
|||
unsub_stop()
|
||||
|
||||
if connection is not None:
|
||||
connection.async_close()
|
||||
connection.async_handle_close()
|
||||
|
||||
try:
|
||||
self._to_write.put_nowait(None)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Notify platform tests for mobile_app."""
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -204,3 +205,130 @@ async def test_notify_ws_works(
|
|||
"code": "unauthorized",
|
||||
"message": "User not linked to this webhook ID",
|
||||
}
|
||||
|
||||
|
||||
async def test_notify_ws_confirming_works(
|
||||
hass, aioclient_mock, setup_push_receiver, hass_ws_client
|
||||
):
|
||||
"""Test notify confirming works."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "mobile_app/push_notification_channel",
|
||||
"webhook_id": "mock-webhook_id",
|
||||
"support_confirm": True,
|
||||
}
|
||||
)
|
||||
|
||||
sub_result = await client.receive_json()
|
||||
assert sub_result["success"]
|
||||
|
||||
# Sent a message that will be delivered locally
|
||||
assert await hass.services.async_call(
|
||||
"notify", "mobile_app_test", {"message": "Hello world"}, blocking=True
|
||||
)
|
||||
|
||||
msg_result = await client.receive_json()
|
||||
confirm_id = msg_result["event"].pop("hass_confirm_id")
|
||||
assert confirm_id is not None
|
||||
assert msg_result["event"] == {"message": "Hello world"}
|
||||
|
||||
# Try to confirm with incorrect confirm ID
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 6,
|
||||
"type": "mobile_app/push_notification_confirm",
|
||||
"webhook_id": "mock-webhook_id",
|
||||
"confirm_id": "incorrect-confirm-id",
|
||||
}
|
||||
)
|
||||
|
||||
result = await client.receive_json()
|
||||
assert not result["success"]
|
||||
assert result["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "Push notification channel not found",
|
||||
}
|
||||
|
||||
# Confirm with correct confirm ID
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 7,
|
||||
"type": "mobile_app/push_notification_confirm",
|
||||
"webhook_id": "mock-webhook_id",
|
||||
"confirm_id": confirm_id,
|
||||
}
|
||||
)
|
||||
|
||||
result = await client.receive_json()
|
||||
assert result["success"]
|
||||
|
||||
# Drop local push channel and try to confirm another message
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 8,
|
||||
"type": "unsubscribe_events",
|
||||
"subscription": 5,
|
||||
}
|
||||
)
|
||||
sub_result = await client.receive_json()
|
||||
assert sub_result["success"]
|
||||
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 9,
|
||||
"type": "mobile_app/push_notification_confirm",
|
||||
"webhook_id": "mock-webhook_id",
|
||||
"confirm_id": confirm_id,
|
||||
}
|
||||
)
|
||||
|
||||
result = await client.receive_json()
|
||||
assert not result["success"]
|
||||
assert result["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "Push notification channel not found",
|
||||
}
|
||||
|
||||
|
||||
async def test_notify_ws_not_confirming(
|
||||
hass, aioclient_mock, setup_push_receiver, hass_ws_client
|
||||
):
|
||||
"""Test we go via cloud when failed to confirm."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "mobile_app/push_notification_channel",
|
||||
"webhook_id": "mock-webhook_id",
|
||||
"support_confirm": True,
|
||||
}
|
||||
)
|
||||
|
||||
sub_result = await client.receive_json()
|
||||
assert sub_result["success"]
|
||||
|
||||
assert await hass.services.async_call(
|
||||
"notify", "mobile_app_test", {"message": "Hello world 1"}, blocking=True
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.mobile_app.push_notification.PUSH_CONFIRM_TIMEOUT", 0
|
||||
):
|
||||
assert await hass.services.async_call(
|
||||
"notify", "mobile_app_test", {"message": "Hello world 2"}, blocking=True
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# When we fail, all unconfirmed ones and failed one are sent via cloud
|
||||
assert len(aioclient_mock.mock_calls) == 2
|
||||
|
||||
# All future ones also go via cloud
|
||||
assert await hass.services.async_call(
|
||||
"notify", "mobile_app_test", {"message": "Hello world 3"}, blocking=True
|
||||
)
|
||||
|
||||
assert len(aioclient_mock.mock_calls) == 3
|
||||
|
|
Loading…
Reference in New Issue