122 lines
3.7 KiB
Python
122 lines
3.7 KiB
Python
|
"""Mobile app websocket API."""
|
||
|
from __future__ import annotations
|
||
|
|
||
|
from functools import wraps
|
||
|
|
||
|
import voluptuous as vol
|
||
|
|
||
|
from homeassistant.components import websocket_api
|
||
|
from homeassistant.core import callback
|
||
|
|
||
|
from .const import CONF_USER_ID, DATA_CONFIG_ENTRIES, DATA_PUSH_CHANNEL, DOMAIN
|
||
|
from .push_notification import PushChannel
|
||
|
|
||
|
|
||
|
@callback
def async_setup_commands(hass):
    """Register the mobile app websocket command handlers on ``hass``."""
    # Registration order matches the original: channel first, then confirm.
    for handler in (
        handle_push_notification_channel,
        handle_push_notification_confirm,
    ):
        websocket_api.async_register_command(hass, handler)
|
||
|
|
||
|
|
||
|
def _ensure_webhook_access(func):
    """Wrap a websocket handler so it only runs for the webhook's owner.

    The wrapper looks up ``msg["webhook_id"]`` among the registered config
    entries and compares the entry's stored user ID with the user of the
    websocket connection. On a missing entry or a user mismatch it replies
    with an error and does not call ``func``.
    """

    @callback
    @wraps(func)
    def with_webhook_access(hass, connection, msg):
        entries = hass.data[DOMAIN][DATA_CONFIG_ENTRIES]
        entry = entries.get(msg["webhook_id"])

        if entry is None:
            # Unknown webhook ID: nothing to check ownership against.
            connection.send_error(
                msg["id"], websocket_api.ERR_NOT_FOUND, "Webhook ID not found"
            )
        elif entry.data[CONF_USER_ID] != connection.user.id:
            # The webhook exists but belongs to a different user.
            connection.send_error(
                msg["id"],
                websocket_api.ERR_UNAUTHORIZED,
                "User not linked to this webhook ID",
            )
        else:
            func(hass, connection, msg)

    return with_webhook_access
|
||
|
|
||
|
|
||
|
@callback
@_ensure_webhook_access
@websocket_api.websocket_command(
    {
        vol.Required("type"): "mobile_app/push_notification_confirm",
        vol.Required("webhook_id"): str,
        vol.Required("confirm_id"): str,
    }
)
def handle_push_notification_confirm(hass, connection, msg):
    """Confirm receipt of a push notification.

    Looks up the push channel registered for ``msg["webhook_id"]`` and asks
    it to confirm ``msg["confirm_id"]``. Replies with a result on success,
    or with an ``ERR_NOT_FOUND`` error when there is no channel for the
    webhook or the channel reports that the confirmation did not succeed.
    """
    channel: PushChannel | None = hass.data[DOMAIN][DATA_PUSH_CHANNEL].get(
        msg["webhook_id"]
    )
    if channel is None:
        connection.send_error(
            msg["id"],
            websocket_api.ERR_NOT_FOUND,
            "Push notification channel not found",
        )
        return

    if channel.async_confirm_notification(msg["confirm_id"]):
        connection.send_result(msg["id"])
        return

    # The channel exists here, so report the confirmation itself as the
    # missing item instead of repeating the "channel not found" message.
    # NOTE(review): presumably a falsy return means the confirm ID is not
    # pending on the channel - verify against PushChannel.
    connection.send_error(
        msg["id"],
        websocket_api.ERR_NOT_FOUND,
        "Push notification confirmation not found",
    )
|
||
|
|
||
|
|
||
|
@websocket_api.websocket_command(
    {
        vol.Required("type"): "mobile_app/push_notification_channel",
        vol.Required("webhook_id"): str,
        vol.Optional("support_confirm", default=False): bool,
    }
)
@_ensure_webhook_access
@websocket_api.async_response
async def handle_push_notification_channel(hass, connection, msg):
    """Open a direct push notification channel over this connection.

    Any channel already registered for the webhook ID is torn down first.
    The new channel is stored under the webhook ID and registered as a
    subscription on the connection, then a result is sent to the client.
    """
    webhook_id = msg["webhook_id"]
    channels: dict[str, PushChannel] = hass.data[DOMAIN][DATA_PUSH_CHANNEL]

    # Tear down whatever channel is currently registered for this webhook ID
    # before putting the new one in its place.
    if webhook_id in channels:
        await channels[webhook_id].async_teardown()

    def send_event(data):
        """Forward channel data to the client as an event on this subscription."""
        return connection.send_message(
            websocket_api.messages.event_message(msg["id"], data)
        )

    @callback
    def on_channel_teardown():
        """Handle teardown."""
        # Only unregister if the stored channel is still the one created
        # below; a newer channel may have replaced it in the meantime.
        if channels.get(webhook_id) == new_channel:
            del channels[webhook_id]

        # Remove subscription from connection if still exists
        connection.subscriptions.pop(msg["id"], None)

    new_channel = PushChannel(
        hass,
        webhook_id,
        msg["support_confirm"],
        send_event,
        on_channel_teardown,
    )
    channels[webhook_id] = new_channel

    def unsubscribe():
        """Schedule teardown of the channel as a task on hass."""
        return hass.async_create_task(new_channel.async_teardown())

    connection.subscriptions[msg["id"]] = unsubscribe
    connection.send_result(msg["id"])
|