306 lines
10 KiB
Python
306 lines
10 KiB
Python
"""Provide a way to connect entities belonging to one device."""
|
|
import logging
|
|
import uuid
|
|
from asyncio import Event
|
|
from collections import OrderedDict
|
|
from typing import List, Optional, cast
|
|
|
|
import attr
|
|
|
|
from homeassistant.core import callback
|
|
from homeassistant.loader import bind_hass
|
|
|
|
from .typing import HomeAssistantType
|
|
|
|
_LOGGER = logging.getLogger(__name__)

# Sentinel meaning "argument not supplied"; lets callers pass None as a real
# value (e.g. to clear area_id) while omitted arguments leave fields alone.
_UNDEF = object()

# Key in hass.data holding the registry (or, while it loads, an asyncio Event).
DATA_REGISTRY = 'device_registry'

# Persistence settings handed to the storage helper.
STORAGE_KEY = 'core.device_registry'
STORAGE_VERSION = 1
SAVE_DELAY = 10  # delay passed to Store.async_delay_save (presumably seconds)

# Connection types; MAC values are normalized with format_mac().
CONNECTION_NETWORK_MAC = 'mac'
CONNECTION_UPNP = 'upnp'
CONNECTION_ZIGBEE = 'zigbee'
|
|
|
|
|
|
@attr.s(slots=True, frozen=True)
class DeviceEntry:
    """Device Registry Entry.

    Frozen attrs record: the registry never mutates an entry in place but
    stores an ``attr.evolve`` copy whenever something changes. Set-valued
    fields run through ``set`` so any iterable passed in is copied.
    """

    # Ids of the config entries that registered this device.
    config_entries = attr.ib(type=set, converter=set, factory=set)
    # (connection_type, value) pairs, e.g. (CONNECTION_NETWORK_MAC, mac).
    connections = attr.ib(type=set, converter=set, factory=set)
    # Identifier tuples used to find this device again.
    identifiers = attr.ib(type=set, converter=set, factory=set)
    manufacturer = attr.ib(type=str, default=None)
    model = attr.ib(type=str, default=None)
    name = attr.ib(type=str, default=None)
    sw_version = attr.ib(type=str, default=None)
    # Registry id of the hub device this one is reached through, if known.
    hub_device_id = attr.ib(type=str, default=None)
    area_id = attr.ib(type=str, default=None)
    # Name set by the user, kept separately from the integration's `name`.
    name_by_user = attr.ib(type=str, default=None)
    # Registry id: a fresh random UUID4 hex string per new entry.
    id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex)
|
|
|
|
|
|
def format_mac(mac):
    """Format the mac address string for entry into dev reg.

    Normalizes a MAC address to lowercase, colon-separated form.

    Accepted layouts:
      * ``AA:BB:CC:DD:EE:FF`` - only lowercased;
      * ``AA-BB-CC-DD-EE-FF`` - dashes replaced by colons;
      * ``aabb.ccdd.eeff``    - dotted groups split into colon pairs;
      * ``AABBCCDDEEFF``      - 12 bare characters split into pairs.

    Anything else is returned unchanged. Characters are not validated as hex.
    """
    if len(mac) == 17 and mac.count(':') == 5:
        return mac.lower()

    bare = mac
    if len(mac) == 17 and mac.count('-') == 5:
        bare = mac.replace('-', '')
    elif len(mac) == 14 and mac.count('.') == 2:
        bare = mac.replace('.', '')

    if len(bare) != 12:
        # Not sure how formatted, hand back the original untouched.
        return mac

    lowered = bare.lower()
    return ':'.join(lowered[pos:pos + 2] for pos in range(0, 12, 2))
|
|
|
|
|
|
class DeviceRegistry:
    """Class to hold a registry of devices.

    Entries live in ``self.devices``, an ``OrderedDict`` keyed by
    ``DeviceEntry.id`` once :meth:`async_load` has run (``None`` before
    that). Every change schedules a delayed save through the storage helper.
    """

    def __init__(self, hass):
        """Initialize the device registry (call async_load before use)."""
        self.hass = hass
        self.devices = None
        self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)

    @callback
    def async_get(self, device_id: str) -> Optional[DeviceEntry]:
        """Return the device with registry id ``device_id``, or None."""
        return self.devices.get(device_id)

    @callback
    def async_get_device(self, identifiers: set, connections: set):
        """Check if device is registered.

        Returns the first device sharing at least one identifier or one
        connection with the given collections, or None if none match.
        Identifiers are checked before connections for each device.
        """
        for device in self.devices.values():
            if (not device.identifiers.isdisjoint(identifiers) or
                    not device.connections.isdisjoint(connections)):
                return device
        return None

    @callback
    def async_get_or_create(self, *, config_entry_id, connections=None,
                            identifiers=None, manufacturer=_UNDEF,
                            model=_UNDEF, name=_UNDEF, sw_version=_UNDEF,
                            via_hub=None):
        """Get device. Create if it doesn't exist.

        A device matches when it shares any identifier or connection with the
        arguments; otherwise a new entry is created. The config entry is then
        linked, identifiers/connections are merged in, and the remaining
        arguments overwrite stored values unless left as ``_UNDEF``.

        ``via_hub`` is an identifier of a hub device; when that hub is
        registered its id is stored as ``hub_device_id``.

        Returns the resulting DeviceEntry, or None (registering nothing) when
        neither identifiers nor connections are given.
        """
        if not identifiers and not connections:
            return None

        # Accept any iterable, as done for connections below. A list would
        # otherwise fail on `issubset` in _async_update_device, after a
        # blank entry had already been added to self.devices.
        identifiers = set() if identifiers is None else set(identifiers)

        if connections is None:
            connections = set()

        # Normalize MAC addresses so differently formatted spellings of the
        # same address match the same device.
        connections = {
            (key, format_mac(value)) if key == CONNECTION_NETWORK_MAC
            else (key, value)
            for key, value in connections
        }

        device = self.async_get_device(identifiers, connections)

        if device is None:
            device = DeviceEntry()
            self.devices[device.id] = device

        if via_hub is not None:
            hub_device = self.async_get_device({via_hub}, set())
            hub_device_id = hub_device.id if hub_device else _UNDEF
        else:
            hub_device_id = _UNDEF

        return self._async_update_device(
            device.id,
            add_config_entry_id=config_entry_id,
            hub_device_id=hub_device_id,
            merge_connections=connections or _UNDEF,
            merge_identifiers=identifiers or _UNDEF,
            manufacturer=manufacturer,
            model=model,
            name=name,
            sw_version=sw_version
        )

    @callback
    def async_update_device(
            self, device_id, *, area_id=_UNDEF, name_by_user=_UNDEF):
        """Update user-editable properties (area, user-given name).

        Pass None to clear a value; omitted arguments are left unchanged.
        Returns the (possibly new) DeviceEntry.
        """
        return self._async_update_device(
            device_id, area_id=area_id, name_by_user=name_by_user)

    @callback
    def _async_update_device(self, device_id, *, add_config_entry_id=_UNDEF,
                             remove_config_entry_id=_UNDEF,
                             merge_connections=_UNDEF,
                             merge_identifiers=_UNDEF,
                             manufacturer=_UNDEF,
                             model=_UNDEF,
                             name=_UNDEF,
                             sw_version=_UNDEF,
                             hub_device_id=_UNDEF,
                             area_id=_UNDEF,
                             name_by_user=_UNDEF):
        """Update device attributes.

        Arguments left as ``_UNDEF`` are ignored; None is a real value.
        ``merge_connections``/``merge_identifiers`` must be sets and are
        unioned into the existing values rather than replacing them.

        Returns the stored entry: the same object if nothing changed,
        otherwise a new evolved DeviceEntry (and a save is scheduled).
        Raises KeyError if ``device_id`` is not registered.
        """
        old = self.devices[device_id]

        changes = {}

        config_entries = old.config_entries

        if (add_config_entry_id is not _UNDEF and
                add_config_entry_id not in old.config_entries):
            config_entries = old.config_entries | {add_config_entry_id}

        if (remove_config_entry_id is not _UNDEF and
                remove_config_entry_id in config_entries):
            config_entries = config_entries - {remove_config_entry_id}

        # `|` and `-` build new sets, so identity tells us a branch ran.
        if config_entries is not old.config_entries:
            changes['config_entries'] = config_entries

        for attr_name, value in (
                ('connections', merge_connections),
                ('identifiers', merge_identifiers),
        ):
            old_value = getattr(old, attr_name)
            # If not undefined, check if `value` contains new items.
            if value is not _UNDEF and not value.issubset(old_value):
                changes[attr_name] = old_value | value

        for attr_name, value in (
                ('manufacturer', manufacturer),
                ('model', model),
                ('name', name),
                ('sw_version', sw_version),
                ('hub_device_id', hub_device_id),
        ):
            if value is not _UNDEF and value != getattr(old, attr_name):
                changes[attr_name] = value

        if area_id is not _UNDEF and area_id != old.area_id:
            changes['area_id'] = area_id

        if (name_by_user is not _UNDEF and
                name_by_user != old.name_by_user):
            changes['name_by_user'] = name_by_user

        if not changes:
            return old

        new = self.devices[device_id] = attr.evolve(old, **changes)
        self.async_schedule_save()
        return new

    async def async_load(self):
        """Load the device registry.

        Starts empty when the store holds no data. Keys introduced in later
        releases are read with ``.get`` so older stored data still loads.
        """
        data = await self._store.async_load()

        devices = OrderedDict()

        if data is not None:
            for device in data['devices']:
                devices[device['id']] = DeviceEntry(
                    config_entries=set(device['config_entries']),
                    # Stored pairs are rebuilt as tuples so they are
                    # hashable set members.
                    connections={tuple(conn) for conn
                                 in device['connections']},
                    identifiers={tuple(iden) for iden
                                 in device['identifiers']},
                    manufacturer=device['manufacturer'],
                    model=device['model'],
                    name=device['name'],
                    sw_version=device['sw_version'],
                    id=device['id'],
                    # Introduced in 0.79
                    hub_device_id=device.get('hub_device_id'),
                    # Introduced in 0.87
                    area_id=device.get('area_id'),
                    name_by_user=device.get('name_by_user')
                )

        self.devices = devices

    @callback
    def async_schedule_save(self):
        """Schedule saving the device registry after SAVE_DELAY."""
        self._store.async_delay_save(self._data_to_save, SAVE_DELAY)

    @callback
    def _data_to_save(self):
        """Return data of device registry to store in a file.

        Sets are emitted as lists; async_load turns them back into sets.
        """
        data = {}

        data['devices'] = [
            {
                'config_entries': list(entry.config_entries),
                'connections': list(entry.connections),
                'identifiers': list(entry.identifiers),
                'manufacturer': entry.manufacturer,
                'model': entry.model,
                'name': entry.name,
                'sw_version': entry.sw_version,
                'id': entry.id,
                'hub_device_id': entry.hub_device_id,
                'area_id': entry.area_id,
                'name_by_user': entry.name_by_user
            } for entry in self.devices.values()
        ]

        return data

    @callback
    def async_clear_config_entry(self, config_entry_id):
        """Clear config entry from registry entries.

        Devices are kept even when no config entries remain.
        """
        for dev_id, device in self.devices.items():
            if config_entry_id in device.config_entries:
                self._async_update_device(
                    dev_id, remove_config_entry_id=config_entry_id)

    @callback
    def async_clear_area_id(self, area_id: str) -> None:
        """Clear area id from registry entries (sets it to None)."""
        for dev_id, device in self.devices.items():
            if area_id == device.area_id:
                self._async_update_device(dev_id, area_id=None)
|
|
|
|
|
@bind_hass
async def async_get_registry(hass: HomeAssistantType) -> DeviceRegistry:
    """Return device registry instance, loading it on first use.

    While the first caller loads the registry, an ``asyncio.Event`` is kept
    under ``DATA_REGISTRY`` so concurrent callers wait for that load instead
    of starting their own. The loaded registry then replaces the event.

    If loading raises or is cancelled, the placeholder is removed and the
    event is set before the error propagates. Callers that were waiting
    then retry the load instead of blocking forever on an event that would
    never be set.
    """
    reg_or_evt = hass.data.get(DATA_REGISTRY)

    if not reg_or_evt:
        evt = hass.data[DATA_REGISTRY] = Event()

        reg = DeviceRegistry(hass)
        try:
            await reg.async_load()
        except BaseException:
            # BaseException so task cancellation is handled too; re-raised.
            if hass.data.get(DATA_REGISTRY) is evt:
                hass.data.pop(DATA_REGISTRY)
            evt.set()
            raise

        hass.data[DATA_REGISTRY] = reg
        evt.set()
        return reg

    if isinstance(reg_or_evt, Event):
        await reg_or_evt.wait()
        loaded = hass.data.get(DATA_REGISTRY)
        if loaded is not None and not isinstance(loaded, Event):
            return cast(DeviceRegistry, loaded)
        # The load we waited on failed (or a new one started): go again.
        return await async_get_registry(hass)

    return cast(DeviceRegistry, reg_or_evt)
|
|
|
|
|
|
@callback
def async_entries_for_area(registry: DeviceRegistry, area_id: str) \
        -> List[DeviceEntry]:
    """Return entries that match an area.

    Devices are returned in the registry's iteration order.
    """
    matching = []
    for entry in registry.devices.values():
        if entry.area_id == area_id:
            matching.append(entry)
    return matching
|