Store notifications in component. Add ws endpoint for fetching. (#16503)
* Store notifications in component. Add ws endpoint for fetching. * Commentspull/16556/head
parent
20f6cb7cc7
commit
50fb59477a
|
@ -10,7 +10,6 @@ from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import persistent_notification
|
||||
from homeassistant.config import load_yaml_config_file
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
@ -92,9 +91,10 @@ async def process_wrong_login(request):
|
|||
msg = ('Login attempt or request with invalid authentication '
|
||||
'from {}'.format(remote_addr))
|
||||
_LOGGER.warning(msg)
|
||||
persistent_notification.async_create(
|
||||
request.app['hass'], msg, 'Login attempt failed',
|
||||
NOTIFICATION_ID_LOGIN)
|
||||
|
||||
hass = request.app['hass']
|
||||
hass.components.persistent_notification.async_create(
|
||||
msg, 'Login attempt failed', NOTIFICATION_ID_LOGIN)
|
||||
|
||||
# Check if ban middleware is loaded
|
||||
if (KEY_BANNED_IPS not in request.app or
|
||||
|
@ -108,15 +108,13 @@ async def process_wrong_login(request):
|
|||
new_ban = IpBan(remote_addr)
|
||||
request.app[KEY_BANNED_IPS].append(new_ban)
|
||||
|
||||
hass = request.app['hass']
|
||||
await hass.async_add_job(
|
||||
update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban)
|
||||
|
||||
_LOGGER.warning(
|
||||
"Banned IP %s for too many login attempts", remote_addr)
|
||||
|
||||
persistent_notification.async_create(
|
||||
hass,
|
||||
hass.components.persistent_notification.async_create(
|
||||
'Too many login attempts from {}'.format(remote_addr),
|
||||
'Banning IP address', NOTIFICATION_ID_BAN)
|
||||
|
||||
|
|
|
@ -6,10 +6,12 @@ https://home-assistant.io/components/persistent_notification/
|
|||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Awaitable
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import callback, HomeAssistant
|
||||
from homeassistant.exceptions import TemplateError
|
||||
from homeassistant.loader import bind_hass
|
||||
|
@ -20,13 +22,17 @@ from homeassistant.util import slugify
|
|||
ATTR_MESSAGE = 'message'
|
||||
ATTR_NOTIFICATION_ID = 'notification_id'
|
||||
ATTR_TITLE = 'title'
|
||||
ATTR_STATUS = 'status'
|
||||
|
||||
DOMAIN = 'persistent_notification'
|
||||
|
||||
ENTITY_ID_FORMAT = DOMAIN + '.{}'
|
||||
|
||||
EVENT_PERSISTENT_NOTIFICATIONS_UPDATED = 'persistent_notifications_updated'
|
||||
|
||||
SERVICE_CREATE = 'create'
|
||||
SERVICE_DISMISS = 'dismiss'
|
||||
SERVICE_MARK_READ = 'mark_read'
|
||||
|
||||
SCHEMA_SERVICE_CREATE = vol.Schema({
|
||||
vol.Required(ATTR_MESSAGE): cv.template,
|
||||
|
@ -38,11 +44,21 @@ SCHEMA_SERVICE_DISMISS = vol.Schema({
|
|||
vol.Required(ATTR_NOTIFICATION_ID): cv.string,
|
||||
})
|
||||
|
||||
SCHEMA_SERVICE_MARK_READ = vol.Schema({
|
||||
vol.Required(ATTR_NOTIFICATION_ID): cv.string,
|
||||
})
|
||||
|
||||
DEFAULT_OBJECT_ID = 'notification'
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
STATE = 'notifying'
|
||||
STATUS_UNREAD = 'unread'
|
||||
STATUS_READ = 'read'
|
||||
|
||||
WS_TYPE_GET_NOTIFICATIONS = 'persistent_notification/get'
|
||||
SCHEMA_WS_GET = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): WS_TYPE_GET_NOTIFICATIONS,
|
||||
})
|
||||
|
||||
|
||||
@bind_hass
|
||||
|
@ -76,7 +92,7 @@ def async_create(hass: HomeAssistant, message: str, title: str = None,
|
|||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_dismiss(hass, notification_id):
|
||||
def async_dismiss(hass: HomeAssistant, notification_id: str) -> None:
|
||||
"""Remove a notification."""
|
||||
data = {ATTR_NOTIFICATION_ID: notification_id}
|
||||
|
||||
|
@ -86,6 +102,9 @@ def async_dismiss(hass, notification_id):
|
|||
@asyncio.coroutine
|
||||
def async_setup(hass: HomeAssistant, config: dict) -> Awaitable[bool]:
|
||||
"""Set up the persistent notification component."""
|
||||
persistent_notifications = OrderedDict()
|
||||
hass.data[DOMAIN] = {'notifications': persistent_notifications}
|
||||
|
||||
@callback
|
||||
def create_service(call):
|
||||
"""Handle a create notification service call."""
|
||||
|
@ -98,6 +117,8 @@ def async_setup(hass: HomeAssistant, config: dict) -> Awaitable[bool]:
|
|||
else:
|
||||
entity_id = async_generate_entity_id(
|
||||
ENTITY_ID_FORMAT, DEFAULT_OBJECT_ID, hass=hass)
|
||||
notification_id = entity_id.split('.')[1]
|
||||
|
||||
attr = {}
|
||||
if title is not None:
|
||||
try:
|
||||
|
@ -120,18 +141,72 @@ def async_setup(hass: HomeAssistant, config: dict) -> Awaitable[bool]:
|
|||
|
||||
hass.states.async_set(entity_id, STATE, attr)
|
||||
|
||||
# Store notification and fire event
|
||||
# This will eventually replace state machine storage
|
||||
persistent_notifications[entity_id] = {
|
||||
ATTR_MESSAGE: message,
|
||||
ATTR_NOTIFICATION_ID: notification_id,
|
||||
ATTR_STATUS: STATUS_UNREAD,
|
||||
ATTR_TITLE: title,
|
||||
}
|
||||
|
||||
hass.bus.async_fire(EVENT_PERSISTENT_NOTIFICATIONS_UPDATED)
|
||||
|
||||
@callback
|
||||
def dismiss_service(call):
|
||||
"""Handle the dismiss notification service call."""
|
||||
notification_id = call.data.get(ATTR_NOTIFICATION_ID)
|
||||
entity_id = ENTITY_ID_FORMAT.format(slugify(notification_id))
|
||||
|
||||
if entity_id not in persistent_notifications:
|
||||
return
|
||||
|
||||
hass.states.async_remove(entity_id)
|
||||
|
||||
del persistent_notifications[entity_id]
|
||||
hass.bus.async_fire(EVENT_PERSISTENT_NOTIFICATIONS_UPDATED)
|
||||
|
||||
@callback
|
||||
def mark_read_service(call):
|
||||
"""Handle the mark_read notification service call."""
|
||||
notification_id = call.data.get(ATTR_NOTIFICATION_ID)
|
||||
entity_id = ENTITY_ID_FORMAT.format(slugify(notification_id))
|
||||
|
||||
if entity_id not in persistent_notifications:
|
||||
_LOGGER.error('Marking persistent_notification read failed: '
|
||||
'Notification ID %s not found.', notification_id)
|
||||
return
|
||||
|
||||
persistent_notifications[entity_id][ATTR_STATUS] = STATUS_READ
|
||||
hass.bus.async_fire(EVENT_PERSISTENT_NOTIFICATIONS_UPDATED)
|
||||
|
||||
hass.services.async_register(DOMAIN, SERVICE_CREATE, create_service,
|
||||
SCHEMA_SERVICE_CREATE)
|
||||
|
||||
hass.services.async_register(DOMAIN, SERVICE_DISMISS, dismiss_service,
|
||||
SCHEMA_SERVICE_DISMISS)
|
||||
|
||||
hass.services.async_register(DOMAIN, SERVICE_MARK_READ, mark_read_service,
|
||||
SCHEMA_SERVICE_MARK_READ)
|
||||
|
||||
hass.components.websocket_api.async_register_command(
|
||||
WS_TYPE_GET_NOTIFICATIONS, websocket_get_notifications,
|
||||
SCHEMA_WS_GET
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@callback
|
||||
def websocket_get_notifications(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg):
|
||||
"""Return a list of persistent_notifications."""
|
||||
connection.to_write.put_nowait(
|
||||
websocket_api.result_message(msg['id'], [
|
||||
{
|
||||
key: data[key] for key in (ATTR_NOTIFICATION_ID, ATTR_MESSAGE,
|
||||
ATTR_STATUS, ATTR_TITLE)
|
||||
}
|
||||
for data in hass.data[DOMAIN]['notifications'].values()
|
||||
])
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""The tests for the persistent notification component."""
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.setup import setup_component, async_setup_component
|
||||
import homeassistant.components.persistent_notification as pn
|
||||
|
||||
from tests.common import get_test_home_assistant
|
||||
|
@ -19,7 +20,9 @@ class TestPersistentNotification:
|
|||
|
||||
def test_create(self):
|
||||
"""Test creating notification without title or notification id."""
|
||||
notifications = self.hass.data[pn.DOMAIN]['notifications']
|
||||
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0
|
||||
assert len(notifications) == 0
|
||||
|
||||
pn.create(self.hass, 'Hello World {{ 1 + 1 }}',
|
||||
title='{{ 1 + 1 }} beers')
|
||||
|
@ -27,54 +30,170 @@ class TestPersistentNotification:
|
|||
|
||||
entity_ids = self.hass.states.entity_ids(pn.DOMAIN)
|
||||
assert len(entity_ids) == 1
|
||||
assert len(notifications) == 1
|
||||
|
||||
state = self.hass.states.get(entity_ids[0])
|
||||
assert state.state == pn.STATE
|
||||
assert state.attributes.get('message') == 'Hello World 2'
|
||||
assert state.attributes.get('title') == '2 beers'
|
||||
|
||||
notification = notifications.get(entity_ids[0])
|
||||
assert notification['status'] == pn.STATUS_UNREAD
|
||||
assert notification['message'] == 'Hello World 2'
|
||||
assert notification['title'] == '2 beers'
|
||||
notifications.clear()
|
||||
|
||||
def test_create_notification_id(self):
|
||||
"""Ensure overwrites existing notification with same id."""
|
||||
notifications = self.hass.data[pn.DOMAIN]['notifications']
|
||||
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0
|
||||
assert len(notifications) == 0
|
||||
|
||||
pn.create(self.hass, 'test', notification_id='Beer 2')
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(self.hass.states.entity_ids()) == 1
|
||||
state = self.hass.states.get('persistent_notification.beer_2')
|
||||
assert len(notifications) == 1
|
||||
|
||||
entity_id = 'persistent_notification.beer_2'
|
||||
state = self.hass.states.get(entity_id)
|
||||
assert state.attributes.get('message') == 'test'
|
||||
|
||||
notification = notifications.get(entity_id)
|
||||
assert notification['message'] == 'test'
|
||||
assert notification['title'] is None
|
||||
|
||||
pn.create(self.hass, 'test 2', notification_id='Beer 2')
|
||||
self.hass.block_till_done()
|
||||
|
||||
# We should have overwritten old one
|
||||
assert len(self.hass.states.entity_ids()) == 1
|
||||
state = self.hass.states.get('persistent_notification.beer_2')
|
||||
state = self.hass.states.get(entity_id)
|
||||
assert state.attributes.get('message') == 'test 2'
|
||||
|
||||
notification = notifications.get(entity_id)
|
||||
assert notification['message'] == 'test 2'
|
||||
notifications.clear()
|
||||
|
||||
def test_create_template_error(self):
|
||||
"""Ensure we output templates if contain error."""
|
||||
notifications = self.hass.data[pn.DOMAIN]['notifications']
|
||||
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0
|
||||
assert len(notifications) == 0
|
||||
|
||||
pn.create(self.hass, '{{ message + 1 }}', '{{ title + 1 }}')
|
||||
self.hass.block_till_done()
|
||||
|
||||
entity_ids = self.hass.states.entity_ids(pn.DOMAIN)
|
||||
assert len(entity_ids) == 1
|
||||
assert len(notifications) == 1
|
||||
|
||||
state = self.hass.states.get(entity_ids[0])
|
||||
assert state.attributes.get('message') == '{{ message + 1 }}'
|
||||
assert state.attributes.get('title') == '{{ title + 1 }}'
|
||||
|
||||
notification = notifications.get(entity_ids[0])
|
||||
assert notification['message'] == '{{ message + 1 }}'
|
||||
assert notification['title'] == '{{ title + 1 }}'
|
||||
notifications.clear()
|
||||
|
||||
def test_dismiss_notification(self):
|
||||
"""Ensure removal of specific notification."""
|
||||
notifications = self.hass.data[pn.DOMAIN]['notifications']
|
||||
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0
|
||||
assert len(notifications) == 0
|
||||
|
||||
pn.create(self.hass, 'test', notification_id='Beer 2')
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 1
|
||||
assert len(notifications) == 1
|
||||
pn.dismiss(self.hass, notification_id='Beer 2')
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0
|
||||
assert len(notifications) == 0
|
||||
notifications.clear()
|
||||
|
||||
def test_mark_read(self):
|
||||
"""Ensure notification is marked as Read."""
|
||||
notifications = self.hass.data[pn.DOMAIN]['notifications']
|
||||
assert len(notifications) == 0
|
||||
|
||||
pn.create(self.hass, 'test', notification_id='Beer 2')
|
||||
self.hass.block_till_done()
|
||||
|
||||
entity_id = 'persistent_notification.beer_2'
|
||||
assert len(notifications) == 1
|
||||
notification = notifications.get(entity_id)
|
||||
assert notification['status'] == pn.STATUS_UNREAD
|
||||
|
||||
self.hass.services.call(pn.DOMAIN, pn.SERVICE_MARK_READ, {
|
||||
'notification_id': 'Beer 2'
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(notifications) == 1
|
||||
notification = notifications.get(entity_id)
|
||||
assert notification['status'] == pn.STATUS_READ
|
||||
notifications.clear()
|
||||
|
||||
|
||||
async def test_ws_get_notifications(hass, hass_ws_client):
|
||||
"""Test websocket endpoint for retrieving persistent notifications."""
|
||||
await async_setup_component(hass, pn.DOMAIN, {})
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': 'persistent_notification/get'
|
||||
})
|
||||
msg = await client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == websocket_api.TYPE_RESULT
|
||||
assert msg['success']
|
||||
notifications = msg['result']
|
||||
assert len(notifications) == 0
|
||||
|
||||
# Create
|
||||
hass.components.persistent_notification.async_create(
|
||||
'test', notification_id='Beer 2')
|
||||
await client.send_json({
|
||||
'id': 6,
|
||||
'type': 'persistent_notification/get'
|
||||
})
|
||||
msg = await client.receive_json()
|
||||
assert msg['id'] == 6
|
||||
assert msg['type'] == websocket_api.TYPE_RESULT
|
||||
assert msg['success']
|
||||
notifications = msg['result']
|
||||
assert len(notifications) == 1
|
||||
notification = notifications[0]
|
||||
assert notification['notification_id'] == 'Beer 2'
|
||||
assert notification['message'] == 'test'
|
||||
assert notification['title'] is None
|
||||
assert notification['status'] == pn.STATUS_UNREAD
|
||||
|
||||
# Mark Read
|
||||
await hass.services.async_call(pn.DOMAIN, pn.SERVICE_MARK_READ, {
|
||||
'notification_id': 'Beer 2'
|
||||
})
|
||||
await client.send_json({
|
||||
'id': 7,
|
||||
'type': 'persistent_notification/get'
|
||||
})
|
||||
msg = await client.receive_json()
|
||||
notifications = msg['result']
|
||||
assert len(notifications) == 1
|
||||
assert notifications[0]['status'] == pn.STATUS_READ
|
||||
|
||||
# Dismiss
|
||||
hass.components.persistent_notification.async_dismiss('Beer 2')
|
||||
await client.send_json({
|
||||
'id': 8,
|
||||
'type': 'persistent_notification/get'
|
||||
})
|
||||
msg = await client.receive_json()
|
||||
notifications = msg['result']
|
||||
assert len(notifications) == 0
|
||||
|
|
Loading…
Reference in New Issue