diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 8d4cd0a5bbf..478b29c75b2 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -26,11 +26,12 @@ CONNECTION_ZIGBEE = 'zigbee' class DeviceEntry: """Device Registry Entry.""" - config_entries = attr.ib(type=set, converter=set) - connections = attr.ib(type=set, converter=set) - identifiers = attr.ib(type=set, converter=set) - manufacturer = attr.ib(type=str) - model = attr.ib(type=str) + config_entries = attr.ib(type=set, converter=set, + default=attr.Factory(set)) + connections = attr.ib(type=set, converter=set, default=attr.Factory(set)) + identifiers = attr.ib(type=set, converter=set, default=attr.Factory(set)) + manufacturer = attr.ib(type=str, default=None) + model = attr.ib(type=str, default=None) name = attr.ib(type=str, default=None) sw_version = attr.ib(type=str, default=None) hub_device_id = attr.ib(type=str, default=None) @@ -56,46 +57,53 @@ class DeviceRegistry: return None @callback - def async_get_or_create(self, *, config_entry_id, connections, identifiers, - manufacturer, model, name=None, sw_version=None, + def async_get_or_create(self, *, config_entry_id, connections=None, + identifiers=None, manufacturer=_UNDEF, + model=_UNDEF, name=_UNDEF, sw_version=_UNDEF, via_hub=None): """Get device. Create if it doesn't exist.""" if not identifiers and not connections: return None + if identifiers is None: + identifiers = set() + + if connections is None: + connections = set() + device = self.async_get_device(identifiers, connections) + if device is None: + device = DeviceEntry() + self.devices[device.id] = device + if via_hub is not None: hub_device = self.async_get_device({via_hub}, set()) - hub_device_id = hub_device.id if hub_device else None + hub_device_id = hub_device.id if hub_device else _UNDEF else: - hub_device_id = None + hub_device_id = _UNDEF - if device is not None: - return self._async_update_device( - device.id, config_entry_id=config_entry_id, - hub_device_id=hub_device_id - ) - - device = DeviceEntry( - config_entries={config_entry_id}, - connections=connections, - identifiers=identifiers, + return self._async_update_device( + device.id, + add_config_entry_id=config_entry_id, + hub_device_id=hub_device_id, + merge_connections=connections, + merge_identifiers=identifiers, manufacturer=manufacturer, model=model, name=name, sw_version=sw_version, - hub_device_id=hub_device_id ) - self.devices[device.id] = device - - self.async_schedule_save() - - return device @callback - def _async_update_device(self, device_id, *, config_entry_id=_UNDEF, + def _async_update_device(self, device_id, *, add_config_entry_id=_UNDEF, remove_config_entry_id=_UNDEF, + merge_connections=_UNDEF, + merge_identifiers=_UNDEF, + manufacturer=_UNDEF, + model=_UNDEF, + name=_UNDEF, + sw_version=_UNDEF, hub_device_id=_UNDEF): """Update device attributes.""" old = self.devices[device_id] @@ -104,21 +112,34 @@ class DeviceRegistry: config_entries = old.config_entries - if (config_entry_id is not _UNDEF and - config_entry_id not in old.config_entries): - config_entries = old.config_entries | {config_entry_id} + if (add_config_entry_id is not _UNDEF and + add_config_entry_id not in old.config_entries): + config_entries = old.config_entries | {add_config_entry_id} if (remove_config_entry_id is not _UNDEF and remove_config_entry_id in config_entries): - config_entries = set(config_entries) - config_entries.remove(remove_config_entry_id) + config_entries = config_entries - {remove_config_entry_id} if config_entries is not old.config_entries: changes['config_entries'] = config_entries - if (hub_device_id is not _UNDEF and - hub_device_id != old.hub_device_id): - changes['hub_device_id'] = hub_device_id + for attr_name, value in ( + ('connections', merge_connections), + ('identifiers', merge_identifiers), + ): + old_value = getattr(old, attr_name) + if value is not _UNDEF and value != old_value: + changes[attr_name] = old_value | value + + for attr_name, value in ( + ('manufacturer', manufacturer), + ('model', model), + ('name', name), + ('sw_version', sw_version), + ('hub_device_id', hub_device_id), + ): + if value is not _UNDEF and value != getattr(old, attr_name): + changes[attr_name] = value if not changes: return old diff --git a/tests/components/config/test_device_registry.py b/tests/components/config/test_device_registry.py index f8ea51cfdc8..87eb0fb2d6f 100644 --- a/tests/components/config/test_device_registry.py +++ b/tests/components/config/test_device_registry.py @@ -27,7 +27,6 @@ async def test_list_devices(hass, client, registry): manufacturer='manufacturer', model='model') registry.async_get_or_create( config_entry_id='1234', - connections={}, identifiers={('bridgeid', '1234')}, manufacturer='manufacturer', model='model', via_hub=('bridgeid', '0123')) diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index b251846c491..a87ad3d483a 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -17,7 +17,10 @@ async def test_get_or_create_returns_same_entry(registry): config_entry_id='1234', connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, identifiers={('bridgeid', '0123')}, - manufacturer='manufacturer', model='model') + sw_version='sw-version', + name='name', + manufacturer='manufacturer', + model='model') entry2 = registry.async_get_or_create( config_entry_id='1234', connections={('ethernet', '11:22:33:44:55:66:77:88')}, @@ -25,15 +28,19 @@ async def test_get_or_create_returns_same_entry(registry): manufacturer='manufacturer', model='model') entry3 = registry.async_get_or_create( config_entry_id='1234', - connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, - identifiers={('bridgeid', '1234')}, - manufacturer='manufacturer', model='model') + connections={('ethernet', '12:34:56:78:90:AB:CD:EF')} + ) assert len(registry.devices) == 1 assert entry.id == entry2.id assert entry.id == entry3.id assert entry.identifiers == {('bridgeid', '0123')} + assert entry3.manufacturer == 'manufacturer' + assert entry3.model == 'model' + assert entry3.name == 'name' + assert entry3.sw_version == 'sw-version' + async def test_requirement_for_identifier_or_connection(registry): """Make sure we do require some descriptor of device."""