Refactor zwave_js setup entry (#72414)
* Refactor zwave_js setup entry * Improve messagepull/72419/head
parent
6245d28907
commit
a5e100176b
homeassistant/components/zwave_js
|
@ -9,6 +9,7 @@ from typing import Any
|
|||
from async_timeout import timeout
|
||||
from zwave_js_server.client import Client as ZwaveClient
|
||||
from zwave_js_server.exceptions import BaseZwaveJSServerError, InvalidServerVersion
|
||||
from zwave_js_server.model.driver import Driver
|
||||
from zwave_js_server.model.node import Node as ZwaveNode
|
||||
from zwave_js_server.model.notification import (
|
||||
EntryControlNotification,
|
||||
|
@ -25,12 +26,6 @@ from homeassistant.const import (
|
|||
ATTR_DEVICE_ID,
|
||||
ATTR_DOMAIN,
|
||||
ATTR_ENTITY_ID,
|
||||
ATTR_IDENTIFIERS,
|
||||
ATTR_MANUFACTURER,
|
||||
ATTR_MODEL,
|
||||
ATTR_NAME,
|
||||
ATTR_SUGGESTED_AREA,
|
||||
ATTR_SW_VERSION,
|
||||
CONF_URL,
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
)
|
||||
|
@ -39,7 +34,7 @@ from homeassistant.exceptions import ConfigEntryNotReady
|
|||
from homeassistant.helpers import device_registry, entity_registry
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType
|
||||
|
||||
from .addon import AddonError, AddonManager, AddonState, get_addon_manager
|
||||
from .api import async_register_api
|
||||
|
@ -154,39 +149,105 @@ def register_node_in_dev_reg(
|
|||
else:
|
||||
ids = {device_id}
|
||||
|
||||
params = {
|
||||
ATTR_IDENTIFIERS: ids,
|
||||
ATTR_SW_VERSION: node.firmware_version,
|
||||
ATTR_NAME: node.name
|
||||
or node.device_config.description
|
||||
or f"Node {node.node_id}",
|
||||
ATTR_MODEL: node.device_config.label,
|
||||
ATTR_MANUFACTURER: node.device_config.manufacturer,
|
||||
}
|
||||
if node.location:
|
||||
params[ATTR_SUGGESTED_AREA] = node.location
|
||||
device = dev_reg.async_get_or_create(config_entry_id=entry.entry_id, **params)
|
||||
device = dev_reg.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
identifiers=ids,
|
||||
sw_version=node.firmware_version,
|
||||
name=node.name or node.device_config.description or f"Node {node.node_id}",
|
||||
model=node.device_config.label,
|
||||
manufacturer=node.device_config.manufacturer,
|
||||
suggested_area=node.location if node.location else UNDEFINED,
|
||||
)
|
||||
|
||||
async_dispatcher_send(hass, EVENT_DEVICE_ADDED_TO_REGISTRY, device)
|
||||
|
||||
return device
|
||||
|
||||
|
||||
async def async_setup_entry( # noqa: C901
|
||||
hass: HomeAssistant, entry: ConfigEntry
|
||||
) -> bool:
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up Z-Wave JS from a config entry."""
|
||||
if use_addon := entry.data.get(CONF_USE_ADDON):
|
||||
await async_ensure_addon_running(hass, entry)
|
||||
|
||||
client = ZwaveClient(entry.data[CONF_URL], async_get_clientsession(hass))
|
||||
entry_hass_data: dict = hass.data[DOMAIN].setdefault(entry.entry_id, {})
|
||||
|
||||
# connect and throw error if connection failed
|
||||
try:
|
||||
async with timeout(CONNECT_TIMEOUT):
|
||||
await client.connect()
|
||||
except InvalidServerVersion as err:
|
||||
if not entry_hass_data.get(DATA_INVALID_SERVER_VERSION_LOGGED):
|
||||
LOGGER.error("Invalid server version: %s", err)
|
||||
entry_hass_data[DATA_INVALID_SERVER_VERSION_LOGGED] = True
|
||||
if use_addon:
|
||||
async_ensure_addon_updated(hass)
|
||||
raise ConfigEntryNotReady from err
|
||||
except (asyncio.TimeoutError, BaseZwaveJSServerError) as err:
|
||||
if not entry_hass_data.get(DATA_CONNECT_FAILED_LOGGED):
|
||||
LOGGER.error("Failed to connect: %s", err)
|
||||
entry_hass_data[DATA_CONNECT_FAILED_LOGGED] = True
|
||||
raise ConfigEntryNotReady from err
|
||||
else:
|
||||
LOGGER.info("Connected to Zwave JS Server")
|
||||
entry_hass_data[DATA_CONNECT_FAILED_LOGGED] = False
|
||||
entry_hass_data[DATA_INVALID_SERVER_VERSION_LOGGED] = False
|
||||
|
||||
dev_reg = device_registry.async_get(hass)
|
||||
ent_reg = entity_registry.async_get(hass)
|
||||
services = ZWaveServices(hass, ent_reg, dev_reg)
|
||||
services.async_register()
|
||||
|
||||
# Set up websocket API
|
||||
async_register_api(hass)
|
||||
|
||||
platform_task = hass.async_create_task(start_platforms(hass, entry, client))
|
||||
entry_hass_data[DATA_START_PLATFORM_TASK] = platform_task
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def start_platforms(
|
||||
hass: HomeAssistant, entry: ConfigEntry, client: ZwaveClient
|
||||
) -> None:
|
||||
"""Start platforms and perform discovery."""
|
||||
entry_hass_data: dict = hass.data[DOMAIN].setdefault(entry.entry_id, {})
|
||||
entry_hass_data[DATA_CLIENT] = client
|
||||
entry_hass_data[DATA_PLATFORM_SETUP] = {}
|
||||
driver_ready = asyncio.Event()
|
||||
|
||||
async def handle_ha_shutdown(event: Event) -> None:
|
||||
"""Handle HA shutdown."""
|
||||
await disconnect_client(hass, entry)
|
||||
|
||||
listen_task = asyncio.create_task(client_listen(hass, entry, client, driver_ready))
|
||||
entry_hass_data[DATA_CLIENT_LISTEN_TASK] = listen_task
|
||||
entry.async_on_unload(
|
||||
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_ha_shutdown)
|
||||
)
|
||||
|
||||
try:
|
||||
await driver_ready.wait()
|
||||
except asyncio.CancelledError:
|
||||
LOGGER.debug("Cancelling start platforms")
|
||||
return
|
||||
|
||||
LOGGER.info("Connection to Zwave JS Server initialized")
|
||||
|
||||
if client.driver is None:
|
||||
raise RuntimeError("Driver not ready.")
|
||||
|
||||
await setup_driver(hass, entry, client, client.driver)
|
||||
|
||||
|
||||
async def setup_driver( # noqa: C901
|
||||
hass: HomeAssistant, entry: ConfigEntry, client: ZwaveClient, driver: Driver
|
||||
) -> None:
|
||||
"""Set up devices using the ready driver."""
|
||||
dev_reg = device_registry.async_get(hass)
|
||||
ent_reg = entity_registry.async_get(hass)
|
||||
entry_hass_data: dict = hass.data[DOMAIN].setdefault(entry.entry_id, {})
|
||||
|
||||
entry_hass_data[DATA_CLIENT] = client
|
||||
platform_setup_tasks = entry_hass_data[DATA_PLATFORM_SETUP] = {}
|
||||
|
||||
platform_setup_tasks = entry_hass_data[DATA_PLATFORM_SETUP]
|
||||
registered_unique_ids: dict[str, dict[str, set[str]]] = defaultdict(dict)
|
||||
discovered_value_ids: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
|
@ -384,7 +445,7 @@ async def async_setup_entry( # noqa: C901
|
|||
{
|
||||
ATTR_DOMAIN: DOMAIN,
|
||||
ATTR_NODE_ID: notification.node.node_id,
|
||||
ATTR_HOME_ID: client.driver.controller.home_id,
|
||||
ATTR_HOME_ID: driver.controller.home_id,
|
||||
ATTR_ENDPOINT: notification.endpoint,
|
||||
ATTR_DEVICE_ID: device.id,
|
||||
ATTR_COMMAND_CLASS: notification.command_class,
|
||||
|
@ -414,7 +475,7 @@ async def async_setup_entry( # noqa: C901
|
|||
event_data = {
|
||||
ATTR_DOMAIN: DOMAIN,
|
||||
ATTR_NODE_ID: notification.node.node_id,
|
||||
ATTR_HOME_ID: client.driver.controller.home_id,
|
||||
ATTR_HOME_ID: driver.controller.home_id,
|
||||
ATTR_DEVICE_ID: device.id,
|
||||
ATTR_COMMAND_CLASS: notification.command_class,
|
||||
}
|
||||
|
@ -487,7 +548,7 @@ async def async_setup_entry( # noqa: C901
|
|||
ZWAVE_JS_VALUE_UPDATED_EVENT,
|
||||
{
|
||||
ATTR_NODE_ID: value.node.node_id,
|
||||
ATTR_HOME_ID: client.driver.controller.home_id,
|
||||
ATTR_HOME_ID: driver.controller.home_id,
|
||||
ATTR_DEVICE_ID: device.id,
|
||||
ATTR_ENTITY_ID: entity_id,
|
||||
ATTR_COMMAND_CLASS: value.command_class,
|
||||
|
@ -502,105 +563,42 @@ async def async_setup_entry( # noqa: C901
|
|||
},
|
||||
)
|
||||
|
||||
# connect and throw error if connection failed
|
||||
try:
|
||||
async with timeout(CONNECT_TIMEOUT):
|
||||
await client.connect()
|
||||
except InvalidServerVersion as err:
|
||||
if not entry_hass_data.get(DATA_INVALID_SERVER_VERSION_LOGGED):
|
||||
LOGGER.error("Invalid server version: %s", err)
|
||||
entry_hass_data[DATA_INVALID_SERVER_VERSION_LOGGED] = True
|
||||
if use_addon:
|
||||
async_ensure_addon_updated(hass)
|
||||
raise ConfigEntryNotReady from err
|
||||
except (asyncio.TimeoutError, BaseZwaveJSServerError) as err:
|
||||
if not entry_hass_data.get(DATA_CONNECT_FAILED_LOGGED):
|
||||
LOGGER.error("Failed to connect: %s", err)
|
||||
entry_hass_data[DATA_CONNECT_FAILED_LOGGED] = True
|
||||
raise ConfigEntryNotReady from err
|
||||
else:
|
||||
LOGGER.info("Connected to Zwave JS Server")
|
||||
entry_hass_data[DATA_CONNECT_FAILED_LOGGED] = False
|
||||
entry_hass_data[DATA_INVALID_SERVER_VERSION_LOGGED] = False
|
||||
# If opt in preference hasn't been specified yet, we do nothing, otherwise
|
||||
# we apply the preference
|
||||
if opted_in := entry.data.get(CONF_DATA_COLLECTION_OPTED_IN):
|
||||
await async_enable_statistics(client)
|
||||
elif opted_in is False:
|
||||
await driver.async_disable_statistics()
|
||||
|
||||
services = ZWaveServices(hass, ent_reg, dev_reg)
|
||||
services.async_register()
|
||||
# Check for nodes that no longer exist and remove them
|
||||
stored_devices = device_registry.async_entries_for_config_entry(
|
||||
dev_reg, entry.entry_id
|
||||
)
|
||||
known_devices = [
|
||||
dev_reg.async_get_device({get_device_id(client, node)})
|
||||
for node in driver.controller.nodes.values()
|
||||
]
|
||||
|
||||
# Set up websocket API
|
||||
async_register_api(hass)
|
||||
# Devices that are in the device registry that are not known by the controller can be removed
|
||||
for device in stored_devices:
|
||||
if device not in known_devices:
|
||||
dev_reg.async_remove_device(device.id)
|
||||
|
||||
async def start_platforms() -> None:
|
||||
"""Start platforms and perform discovery."""
|
||||
driver_ready = asyncio.Event()
|
||||
# run discovery on all ready nodes
|
||||
await asyncio.gather(
|
||||
*(async_on_node_added(node) for node in driver.controller.nodes.values())
|
||||
)
|
||||
|
||||
async def handle_ha_shutdown(event: Event) -> None:
|
||||
"""Handle HA shutdown."""
|
||||
await disconnect_client(hass, entry)
|
||||
|
||||
listen_task = asyncio.create_task(
|
||||
client_listen(hass, entry, client, driver_ready)
|
||||
# listen for new nodes being added to the mesh
|
||||
entry.async_on_unload(
|
||||
driver.controller.on(
|
||||
"node added",
|
||||
lambda event: hass.async_create_task(async_on_node_added(event["node"])),
|
||||
)
|
||||
entry_hass_data[DATA_CLIENT_LISTEN_TASK] = listen_task
|
||||
entry.async_on_unload(
|
||||
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, handle_ha_shutdown)
|
||||
)
|
||||
|
||||
try:
|
||||
await driver_ready.wait()
|
||||
except asyncio.CancelledError:
|
||||
LOGGER.debug("Cancelling start platforms")
|
||||
return
|
||||
|
||||
LOGGER.info("Connection to Zwave JS Server initialized")
|
||||
|
||||
# If opt in preference hasn't been specified yet, we do nothing, otherwise
|
||||
# we apply the preference
|
||||
if opted_in := entry.data.get(CONF_DATA_COLLECTION_OPTED_IN):
|
||||
await async_enable_statistics(client)
|
||||
elif opted_in is False:
|
||||
await client.driver.async_disable_statistics()
|
||||
|
||||
# Check for nodes that no longer exist and remove them
|
||||
stored_devices = device_registry.async_entries_for_config_entry(
|
||||
dev_reg, entry.entry_id
|
||||
)
|
||||
known_devices = [
|
||||
dev_reg.async_get_device({get_device_id(client, node)})
|
||||
for node in client.driver.controller.nodes.values()
|
||||
]
|
||||
|
||||
# Devices that are in the device registry that are not known by the controller can be removed
|
||||
for device in stored_devices:
|
||||
if device not in known_devices:
|
||||
dev_reg.async_remove_device(device.id)
|
||||
|
||||
# run discovery on all ready nodes
|
||||
await asyncio.gather(
|
||||
*(
|
||||
async_on_node_added(node)
|
||||
for node in client.driver.controller.nodes.values()
|
||||
)
|
||||
)
|
||||
|
||||
# listen for new nodes being added to the mesh
|
||||
entry.async_on_unload(
|
||||
client.driver.controller.on(
|
||||
"node added",
|
||||
lambda event: hass.async_create_task(
|
||||
async_on_node_added(event["node"])
|
||||
),
|
||||
)
|
||||
)
|
||||
# listen for nodes being removed from the mesh
|
||||
# NOTE: This will not remove nodes that were removed when HA was not running
|
||||
entry.async_on_unload(
|
||||
client.driver.controller.on("node removed", async_on_node_removed)
|
||||
)
|
||||
|
||||
platform_task = hass.async_create_task(start_platforms())
|
||||
entry_hass_data[DATA_START_PLATFORM_TASK] = platform_task
|
||||
|
||||
return True
|
||||
)
|
||||
# listen for nodes being removed from the mesh
|
||||
# NOTE: This will not remove nodes that were removed when HA was not running
|
||||
entry.async_on_unload(driver.controller.on("node removed", async_on_node_removed))
|
||||
|
||||
|
||||
async def client_listen(
|
||||
|
|
Loading…
Reference in New Issue