core/homeassistant/helpers/device_registry.py

273 lines
9.1 KiB
Python
Raw Normal View History

"""Provide a way to connect entities belonging to one device."""
import logging
import uuid
from collections import OrderedDict
import attr
from homeassistant.core import callback
from homeassistant.loader import bind_hass
_LOGGER = logging.getLogger(__name__)
_UNDEF = object()
DATA_REGISTRY = 'device_registry'
STORAGE_KEY = 'core.device_registry'
STORAGE_VERSION = 1
SAVE_DELAY = 10
CONNECTION_NETWORK_MAC = 'mac'
CONNECTION_UPNP = 'upnp'
CONNECTION_ZIGBEE = 'zigbee'
@attr.s(slots=True, frozen=True)
class DeviceEntry:
"""Device Registry Entry."""
config_entries = attr.ib(type=set, converter=set,
default=attr.Factory(set))
connections = attr.ib(type=set, converter=set, default=attr.Factory(set))
identifiers = attr.ib(type=set, converter=set, default=attr.Factory(set))
manufacturer = attr.ib(type=str, default=None)
model = attr.ib(type=str, default=None)
name = attr.ib(type=str, default=None)
sw_version = attr.ib(type=str, default=None)
hub_device_id = attr.ib(type=str, default=None)
area_id = attr.ib(type=str, default=None)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
def format_mac(mac):
"""Format the mac address string for entry into dev reg."""
to_test = mac
if len(to_test) == 17 and to_test.count(':') == 5:
return to_test.lower()
2018-11-06 18:27:52 +00:00
if len(to_test) == 17 and to_test.count('-') == 5:
to_test = to_test.replace('-', '')
elif len(to_test) == 14 and to_test.count('.') == 2:
to_test = to_test.replace('.', '')
if len(to_test) == 12:
# no : included
return ':'.join(to_test.lower()[i:i + 2] for i in range(0, 12, 2))
# Not sure how formatted, return original
return mac
class DeviceRegistry:
"""Class to hold a registry of devices."""
def __init__(self, hass):
"""Initialize the device registry."""
self.hass = hass
self.devices = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
@callback
def async_get_device(self, identifiers: set, connections: set):
"""Check if device is registered."""
for device in self.devices.values():
if any(iden in device.identifiers for iden in identifiers) or \
any(conn in device.connections for conn in connections):
return device
return None
@callback
def async_get_or_create(self, *, config_entry_id, connections=None,
identifiers=None, manufacturer=_UNDEF,
model=_UNDEF, name=_UNDEF, sw_version=_UNDEF,
via_hub=None):
"""Get device. Create if it doesn't exist."""
if not identifiers and not connections:
return None
if identifiers is None:
identifiers = set()
if connections is None:
connections = set()
connections = {
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC
else (key, value)
for key, value in connections
}
device = self.async_get_device(identifiers, connections)
if device is None:
device = DeviceEntry()
self.devices[device.id] = device
if via_hub is not None:
hub_device = self.async_get_device({via_hub}, set())
hub_device_id = hub_device.id if hub_device else _UNDEF
else:
hub_device_id = _UNDEF
return self._async_update_device(
device.id,
add_config_entry_id=config_entry_id,
hub_device_id=hub_device_id,
merge_connections=connections or _UNDEF,
merge_identifiers=identifiers or _UNDEF,
manufacturer=manufacturer,
model=model,
name=name,
sw_version=sw_version
)
@callback
def async_update_device(self, device_id, *, area_id=_UNDEF):
"""Update properties of a device."""
return self._async_update_device(device_id, area_id=area_id)
@callback
def _async_update_device(self, device_id, *, add_config_entry_id=_UNDEF,
remove_config_entry_id=_UNDEF,
merge_connections=_UNDEF,
merge_identifiers=_UNDEF,
manufacturer=_UNDEF,
model=_UNDEF,
name=_UNDEF,
sw_version=_UNDEF,
hub_device_id=_UNDEF,
area_id=_UNDEF):
"""Update device attributes."""
old = self.devices[device_id]
changes = {}
config_entries = old.config_entries
if (add_config_entry_id is not _UNDEF and
add_config_entry_id not in old.config_entries):
config_entries = old.config_entries | {add_config_entry_id}
if (remove_config_entry_id is not _UNDEF and
remove_config_entry_id in config_entries):
config_entries = config_entries - {remove_config_entry_id}
if config_entries is not old.config_entries:
changes['config_entries'] = config_entries
for attr_name, value in (
('connections', merge_connections),
('identifiers', merge_identifiers),
):
old_value = getattr(old, attr_name)
# If not undefined, check if `value` contains new items.
if value is not _UNDEF and not value.issubset(old_value):
changes[attr_name] = old_value | value
for attr_name, value in (
('manufacturer', manufacturer),
('model', model),
('name', name),
('sw_version', sw_version),
('hub_device_id', hub_device_id),
):
if value is not _UNDEF and value != getattr(old, attr_name):
changes[attr_name] = value
if (area_id is not _UNDEF and area_id != old.area_id):
changes['area_id'] = area_id
if not changes:
return old
new = self.devices[device_id] = attr.evolve(old, **changes)
self.async_schedule_save()
return new
async def async_load(self):
"""Load the device registry."""
data = await self._store.async_load()
devices = OrderedDict()
if data is not None:
for device in data['devices']:
devices[device['id']] = DeviceEntry(
config_entries=set(device['config_entries']),
connections={tuple(conn) for conn
in device['connections']},
identifiers={tuple(iden) for iden
in device['identifiers']},
manufacturer=device['manufacturer'],
model=device['model'],
name=device['name'],
sw_version=device['sw_version'],
id=device['id'],
# Introduced in 0.79
hub_device_id=device.get('hub_device_id'),
# Introduced in 0.87
area_id=device.get('area_id')
)
self.devices = devices
@callback
def async_schedule_save(self):
"""Schedule saving the device registry."""
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
@callback
def _data_to_save(self):
"""Return data of device registry to store in a file."""
data = {}
data['devices'] = [
{
'config_entries': list(entry.config_entries),
'connections': list(entry.connections),
'identifiers': list(entry.identifiers),
'manufacturer': entry.manufacturer,
'model': entry.model,
'name': entry.name,
'sw_version': entry.sw_version,
'id': entry.id,
'hub_device_id': entry.hub_device_id,
'area_id': entry.area_id
} for entry in self.devices.values()
]
return data
@callback
def async_clear_config_entry(self, config_entry_id):
"""Clear config entry from registry entries."""
for dev_id, device in self.devices.items():
if config_entry_id in device.config_entries:
self._async_update_device(
dev_id, remove_config_entry_id=config_entry_id)
@callback
def async_clear_area_id(self, area_id: str) -> None:
"""Clear area id from registry entries."""
for dev_id, device in self.devices.items():
if area_id == device.area_id:
self._async_update_device(dev_id, area_id=None)
@bind_hass
async def async_get_registry(hass) -> DeviceRegistry:
"""Return device registry instance."""
task = hass.data.get(DATA_REGISTRY)
if task is None:
async def _load_reg():
registry = DeviceRegistry(hass)
await registry.async_load()
return registry
task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg())
return await task