334 lines
12 KiB
Python
334 lines
12 KiB
Python
"""Handle MySensors gateways."""
|
|
import asyncio
|
|
from collections import defaultdict
|
|
import logging
|
|
import socket
|
|
import sys
|
|
from timeit import default_timer as timer
|
|
|
|
import async_timeout
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.const import (
|
|
CONF_NAME, CONF_OPTIMISTIC, EVENT_HOMEASSISTANT_STOP)
|
|
from homeassistant.core import callback
|
|
from homeassistant.helpers import discovery
|
|
import homeassistant.helpers.config_validation as cv
|
|
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
|
from homeassistant.setup import async_setup_component
|
|
|
|
from .const import (
|
|
ATTR_DEVICES, CONF_BAUD_RATE, CONF_DEVICE, CONF_GATEWAYS, CONF_NODES,
|
|
CONF_PERSISTENCE, CONF_PERSISTENCE_FILE, CONF_RETAIN, CONF_TCP_PORT,
|
|
CONF_TOPIC_IN_PREFIX, CONF_TOPIC_OUT_PREFIX, CONF_VERSION, DOMAIN,
|
|
MYSENSORS_CONST_SCHEMA, MYSENSORS_GATEWAYS, PLATFORM, SCHEMA,
|
|
SIGNAL_CALLBACK, TYPE)
|
|
from .device import get_mysensors_devices
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
GATEWAY_READY_TIMEOUT = 15.0
|
|
MQTT_COMPONENT = 'mqtt'
|
|
MYSENSORS_GATEWAY_READY = 'mysensors_gateway_ready_{}'
|
|
|
|
|
|
def is_serial_port(value):
|
|
"""Validate that value is a windows serial port or a unix device."""
|
|
if sys.platform.startswith('win'):
|
|
ports = ('COM{}'.format(idx + 1) for idx in range(256))
|
|
if value in ports:
|
|
return value
|
|
raise vol.Invalid('{} is not a serial port'.format(value))
|
|
return cv.isdevice(value)
|
|
|
|
|
|
def is_socket_address(value):
|
|
"""Validate that value is a valid address."""
|
|
try:
|
|
socket.getaddrinfo(value, None)
|
|
return value
|
|
except OSError:
|
|
raise vol.Invalid('Device is not a valid domain name or ip address')
|
|
|
|
|
|
def get_mysensors_gateway(hass, gateway_id):
|
|
"""Return MySensors gateway."""
|
|
if MYSENSORS_GATEWAYS not in hass.data:
|
|
hass.data[MYSENSORS_GATEWAYS] = {}
|
|
gateways = hass.data.get(MYSENSORS_GATEWAYS)
|
|
return gateways.get(gateway_id)
|
|
|
|
|
|
async def setup_gateways(hass, config):
|
|
"""Set up all gateways."""
|
|
conf = config[DOMAIN]
|
|
gateways = {}
|
|
|
|
for index, gateway_conf in enumerate(conf[CONF_GATEWAYS]):
|
|
persistence_file = gateway_conf.get(
|
|
CONF_PERSISTENCE_FILE,
|
|
hass.config.path('mysensors{}.pickle'.format(index + 1)))
|
|
ready_gateway = await _get_gateway(
|
|
hass, config, gateway_conf, persistence_file)
|
|
if ready_gateway is not None:
|
|
gateways[id(ready_gateway)] = ready_gateway
|
|
|
|
return gateways
|
|
|
|
|
|
async def _get_gateway(hass, config, gateway_conf, persistence_file):
|
|
"""Return gateway after setup of the gateway."""
|
|
from mysensors import mysensors
|
|
|
|
conf = config[DOMAIN]
|
|
persistence = conf[CONF_PERSISTENCE]
|
|
version = conf[CONF_VERSION]
|
|
device = gateway_conf[CONF_DEVICE]
|
|
baud_rate = gateway_conf[CONF_BAUD_RATE]
|
|
tcp_port = gateway_conf[CONF_TCP_PORT]
|
|
in_prefix = gateway_conf.get(CONF_TOPIC_IN_PREFIX, '')
|
|
out_prefix = gateway_conf.get(CONF_TOPIC_OUT_PREFIX, '')
|
|
|
|
if device == MQTT_COMPONENT:
|
|
if not await async_setup_component(hass, MQTT_COMPONENT, config):
|
|
return None
|
|
mqtt = hass.components.mqtt
|
|
retain = conf[CONF_RETAIN]
|
|
|
|
def pub_callback(topic, payload, qos, retain):
|
|
"""Call MQTT publish function."""
|
|
mqtt.async_publish(topic, payload, qos, retain)
|
|
|
|
def sub_callback(topic, sub_cb, qos):
|
|
"""Call MQTT subscribe function."""
|
|
@callback
|
|
def internal_callback(*args):
|
|
"""Call callback."""
|
|
sub_cb(*args)
|
|
|
|
hass.async_create_task(
|
|
mqtt.async_subscribe(topic, internal_callback, qos))
|
|
|
|
gateway = mysensors.AsyncMQTTGateway(
|
|
pub_callback, sub_callback, in_prefix=in_prefix,
|
|
out_prefix=out_prefix, retain=retain, loop=hass.loop,
|
|
event_callback=None, persistence=persistence,
|
|
persistence_file=persistence_file,
|
|
protocol_version=version)
|
|
else:
|
|
try:
|
|
await hass.async_add_job(is_serial_port, device)
|
|
gateway = mysensors.AsyncSerialGateway(
|
|
device, baud=baud_rate, loop=hass.loop,
|
|
event_callback=None, persistence=persistence,
|
|
persistence_file=persistence_file,
|
|
protocol_version=version)
|
|
except vol.Invalid:
|
|
try:
|
|
await hass.async_add_job(is_socket_address, device)
|
|
# valid ip address
|
|
gateway = mysensors.AsyncTCPGateway(
|
|
device, port=tcp_port, loop=hass.loop, event_callback=None,
|
|
persistence=persistence, persistence_file=persistence_file,
|
|
protocol_version=version)
|
|
except vol.Invalid:
|
|
# invalid ip address
|
|
return None
|
|
gateway.metric = hass.config.units.is_metric
|
|
gateway.optimistic = conf[CONF_OPTIMISTIC]
|
|
gateway.device = device
|
|
gateway.event_callback = _gw_callback_factory(hass, config)
|
|
gateway.nodes_config = gateway_conf[CONF_NODES]
|
|
if persistence:
|
|
await gateway.start_persistence()
|
|
|
|
return gateway
|
|
|
|
|
|
async def finish_setup(hass, hass_config, gateways):
|
|
"""Load any persistent devices and platforms and start gateway."""
|
|
discover_tasks = []
|
|
start_tasks = []
|
|
for gateway in gateways.values():
|
|
discover_tasks.append(_discover_persistent_devices(
|
|
hass, hass_config, gateway))
|
|
start_tasks.append(_gw_start(hass, gateway))
|
|
if discover_tasks:
|
|
# Make sure all devices and platforms are loaded before gateway start.
|
|
await asyncio.wait(discover_tasks, loop=hass.loop)
|
|
if start_tasks:
|
|
await asyncio.wait(start_tasks, loop=hass.loop)
|
|
|
|
|
|
async def _discover_persistent_devices(hass, hass_config, gateway):
|
|
"""Discover platforms for devices loaded via persistence file."""
|
|
tasks = []
|
|
new_devices = defaultdict(list)
|
|
for node_id in gateway.sensors:
|
|
node = gateway.sensors[node_id]
|
|
for child in node.children.values():
|
|
validated = _validate_child(gateway, node_id, child)
|
|
for platform, dev_ids in validated.items():
|
|
new_devices[platform].extend(dev_ids)
|
|
for platform, dev_ids in new_devices.items():
|
|
tasks.append(_discover_mysensors_platform(
|
|
hass, hass_config, platform, dev_ids))
|
|
if tasks:
|
|
await asyncio.wait(tasks, loop=hass.loop)
|
|
|
|
|
|
@callback
|
|
def _discover_mysensors_platform(hass, hass_config, platform, new_devices):
|
|
"""Discover a MySensors platform."""
|
|
task = hass.async_create_task(discovery.async_load_platform(
|
|
hass, platform, DOMAIN,
|
|
{ATTR_DEVICES: new_devices, CONF_NAME: DOMAIN}, hass_config))
|
|
return task
|
|
|
|
|
|
async def _gw_start(hass, gateway):
|
|
"""Start the gateway."""
|
|
# Don't use hass.async_create_task to avoid holding up setup indefinitely.
|
|
connect_task = hass.loop.create_task(gateway.start())
|
|
|
|
@callback
|
|
def gw_stop(event):
|
|
"""Trigger to stop the gateway."""
|
|
hass.async_create_task(gateway.stop())
|
|
if not connect_task.done():
|
|
connect_task.cancel()
|
|
|
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, gw_stop)
|
|
if gateway.device == 'mqtt':
|
|
# Gatways connected via mqtt doesn't send gateway ready message.
|
|
return
|
|
gateway_ready = asyncio.Future()
|
|
gateway_ready_key = MYSENSORS_GATEWAY_READY.format(id(gateway))
|
|
hass.data[gateway_ready_key] = gateway_ready
|
|
|
|
try:
|
|
with async_timeout.timeout(GATEWAY_READY_TIMEOUT, loop=hass.loop):
|
|
await gateway_ready
|
|
except asyncio.TimeoutError:
|
|
_LOGGER.warning(
|
|
"Gateway %s not ready after %s secs so continuing with setup",
|
|
gateway.device, GATEWAY_READY_TIMEOUT)
|
|
finally:
|
|
hass.data.pop(gateway_ready_key, None)
|
|
|
|
|
|
def _gw_callback_factory(hass, hass_config):
|
|
"""Return a new callback for the gateway."""
|
|
@callback
|
|
def mysensors_callback(msg):
|
|
"""Handle messages from a MySensors gateway."""
|
|
start = timer()
|
|
_LOGGER.debug(
|
|
"Node update: node %s child %s", msg.node_id, msg.child_id)
|
|
|
|
_set_gateway_ready(hass, msg)
|
|
|
|
try:
|
|
child = msg.gateway.sensors[msg.node_id].children[msg.child_id]
|
|
except KeyError:
|
|
_LOGGER.debug("Not a child update for node %s", msg.node_id)
|
|
return
|
|
|
|
signals = []
|
|
|
|
# Update all platforms for the device via dispatcher.
|
|
# Add/update entity if schema validates to true.
|
|
validated = _validate_child(msg.gateway, msg.node_id, child)
|
|
for platform, dev_ids in validated.items():
|
|
devices = get_mysensors_devices(hass, platform)
|
|
new_dev_ids = []
|
|
for dev_id in dev_ids:
|
|
if dev_id in devices:
|
|
signals.append(SIGNAL_CALLBACK.format(*dev_id))
|
|
else:
|
|
new_dev_ids.append(dev_id)
|
|
if new_dev_ids:
|
|
_discover_mysensors_platform(
|
|
hass, hass_config, platform, new_dev_ids)
|
|
for signal in set(signals):
|
|
# Only one signal per device is needed.
|
|
# A device can have multiple platforms, ie multiple schemas.
|
|
# FOR LATER: Add timer to not signal if another update comes in.
|
|
async_dispatcher_send(hass, signal)
|
|
end = timer()
|
|
if end - start > 0.1:
|
|
_LOGGER.debug(
|
|
"Callback for node %s child %s took %.3f seconds",
|
|
msg.node_id, msg.child_id, end - start)
|
|
return mysensors_callback
|
|
|
|
|
|
@callback
|
|
def _set_gateway_ready(hass, msg):
|
|
"""Set asyncio future result if gateway is ready."""
|
|
if (msg.type != msg.gateway.const.MessageType.internal or
|
|
msg.sub_type != msg.gateway.const.Internal.I_GATEWAY_READY):
|
|
return
|
|
gateway_ready = hass.data.get(MYSENSORS_GATEWAY_READY.format(
|
|
id(msg.gateway)))
|
|
if gateway_ready is None or gateway_ready.cancelled():
|
|
return
|
|
gateway_ready.set_result(True)
|
|
|
|
|
|
def _validate_child(gateway, node_id, child):
|
|
"""Validate that a child has the correct values according to schema.
|
|
|
|
Return a dict of platform with a list of device ids for validated devices.
|
|
"""
|
|
validated = defaultdict(list)
|
|
|
|
if not child.values:
|
|
_LOGGER.debug(
|
|
"No child values for node %s child %s", node_id, child.id)
|
|
return validated
|
|
if gateway.sensors[node_id].sketch_name is None:
|
|
_LOGGER.debug("Node %s is missing sketch name", node_id)
|
|
return validated
|
|
pres = gateway.const.Presentation
|
|
set_req = gateway.const.SetReq
|
|
s_name = next(
|
|
(member.name for member in pres if member.value == child.type), None)
|
|
if s_name not in MYSENSORS_CONST_SCHEMA:
|
|
_LOGGER.warning("Child type %s is not supported", s_name)
|
|
return validated
|
|
child_schemas = MYSENSORS_CONST_SCHEMA[s_name]
|
|
|
|
def msg(name):
|
|
"""Return a message for an invalid schema."""
|
|
return "{} requires value_type {}".format(
|
|
pres(child.type).name, set_req[name].name)
|
|
|
|
for schema in child_schemas:
|
|
platform = schema[PLATFORM]
|
|
v_name = schema[TYPE]
|
|
value_type = next(
|
|
(member.value for member in set_req if member.name == v_name),
|
|
None)
|
|
if value_type is None:
|
|
continue
|
|
_child_schema = child.get_schema(gateway.protocol_version)
|
|
vol_schema = _child_schema.extend(
|
|
{vol.Required(set_req[key].value, msg=msg(key)):
|
|
_child_schema.schema.get(set_req[key].value, val)
|
|
for key, val in schema.get(SCHEMA, {v_name: cv.string}).items()},
|
|
extra=vol.ALLOW_EXTRA)
|
|
try:
|
|
vol_schema(child.values)
|
|
except vol.Invalid as exc:
|
|
level = (logging.WARNING if value_type in child.values
|
|
else logging.DEBUG)
|
|
_LOGGER.log(
|
|
level,
|
|
"Invalid values: %s: %s platform: node %s child %s: %s",
|
|
child.values, platform, node_id, child.id, exc)
|
|
continue
|
|
dev_id = id(gateway), node_id, child.id, value_type
|
|
validated[platform].append(dev_id)
|
|
return validated
|