Update new values coming in for dev registry (#16852)
* Update new values coming in for dev registry * fix Lint+Test;2Cpull/16914/head
parent
29db43edb2
commit
da3342f1aa
|
@ -26,11 +26,12 @@ CONNECTION_ZIGBEE = 'zigbee'
|
|||
class DeviceEntry:
|
||||
"""Device Registry Entry."""
|
||||
|
||||
config_entries = attr.ib(type=set, converter=set)
|
||||
connections = attr.ib(type=set, converter=set)
|
||||
identifiers = attr.ib(type=set, converter=set)
|
||||
manufacturer = attr.ib(type=str)
|
||||
model = attr.ib(type=str)
|
||||
config_entries = attr.ib(type=set, converter=set,
|
||||
default=attr.Factory(set))
|
||||
connections = attr.ib(type=set, converter=set, default=attr.Factory(set))
|
||||
identifiers = attr.ib(type=set, converter=set, default=attr.Factory(set))
|
||||
manufacturer = attr.ib(type=str, default=None)
|
||||
model = attr.ib(type=str, default=None)
|
||||
name = attr.ib(type=str, default=None)
|
||||
sw_version = attr.ib(type=str, default=None)
|
||||
hub_device_id = attr.ib(type=str, default=None)
|
||||
|
@ -56,46 +57,53 @@ class DeviceRegistry:
|
|||
return None
|
||||
|
||||
@callback
|
||||
def async_get_or_create(self, *, config_entry_id, connections, identifiers,
|
||||
manufacturer, model, name=None, sw_version=None,
|
||||
def async_get_or_create(self, *, config_entry_id, connections=None,
|
||||
identifiers=None, manufacturer=_UNDEF,
|
||||
model=_UNDEF, name=_UNDEF, sw_version=_UNDEF,
|
||||
via_hub=None):
|
||||
"""Get device. Create if it doesn't exist."""
|
||||
if not identifiers and not connections:
|
||||
return None
|
||||
|
||||
if identifiers is None:
|
||||
identifiers = set()
|
||||
|
||||
if connections is None:
|
||||
connections = set()
|
||||
|
||||
device = self.async_get_device(identifiers, connections)
|
||||
|
||||
if device is None:
|
||||
device = DeviceEntry()
|
||||
self.devices[device.id] = device
|
||||
|
||||
if via_hub is not None:
|
||||
hub_device = self.async_get_device({via_hub}, set())
|
||||
hub_device_id = hub_device.id if hub_device else None
|
||||
hub_device_id = hub_device.id if hub_device else _UNDEF
|
||||
else:
|
||||
hub_device_id = None
|
||||
hub_device_id = _UNDEF
|
||||
|
||||
if device is not None:
|
||||
return self._async_update_device(
|
||||
device.id, config_entry_id=config_entry_id,
|
||||
hub_device_id=hub_device_id
|
||||
)
|
||||
|
||||
device = DeviceEntry(
|
||||
config_entries={config_entry_id},
|
||||
connections=connections,
|
||||
identifiers=identifiers,
|
||||
return self._async_update_device(
|
||||
device.id,
|
||||
add_config_entry_id=config_entry_id,
|
||||
hub_device_id=hub_device_id,
|
||||
merge_connections=connections,
|
||||
merge_identifiers=identifiers,
|
||||
manufacturer=manufacturer,
|
||||
model=model,
|
||||
name=name,
|
||||
sw_version=sw_version,
|
||||
hub_device_id=hub_device_id
|
||||
)
|
||||
self.devices[device.id] = device
|
||||
|
||||
self.async_schedule_save()
|
||||
|
||||
return device
|
||||
|
||||
@callback
|
||||
def _async_update_device(self, device_id, *, config_entry_id=_UNDEF,
|
||||
def _async_update_device(self, device_id, *, add_config_entry_id=_UNDEF,
|
||||
remove_config_entry_id=_UNDEF,
|
||||
merge_connections=_UNDEF,
|
||||
merge_identifiers=_UNDEF,
|
||||
manufacturer=_UNDEF,
|
||||
model=_UNDEF,
|
||||
name=_UNDEF,
|
||||
sw_version=_UNDEF,
|
||||
hub_device_id=_UNDEF):
|
||||
"""Update device attributes."""
|
||||
old = self.devices[device_id]
|
||||
|
@ -104,21 +112,34 @@ class DeviceRegistry:
|
|||
|
||||
config_entries = old.config_entries
|
||||
|
||||
if (config_entry_id is not _UNDEF and
|
||||
config_entry_id not in old.config_entries):
|
||||
config_entries = old.config_entries | {config_entry_id}
|
||||
if (add_config_entry_id is not _UNDEF and
|
||||
add_config_entry_id not in old.config_entries):
|
||||
config_entries = old.config_entries | {add_config_entry_id}
|
||||
|
||||
if (remove_config_entry_id is not _UNDEF and
|
||||
remove_config_entry_id in config_entries):
|
||||
config_entries = set(config_entries)
|
||||
config_entries.remove(remove_config_entry_id)
|
||||
config_entries = config_entries - {remove_config_entry_id}
|
||||
|
||||
if config_entries is not old.config_entries:
|
||||
changes['config_entries'] = config_entries
|
||||
|
||||
if (hub_device_id is not _UNDEF and
|
||||
hub_device_id != old.hub_device_id):
|
||||
changes['hub_device_id'] = hub_device_id
|
||||
for attr_name, value in (
|
||||
('connections', merge_connections),
|
||||
('identifiers', merge_identifiers),
|
||||
):
|
||||
old_value = getattr(old, attr_name)
|
||||
if value is not _UNDEF and value != old_value:
|
||||
changes[attr_name] = old_value | value
|
||||
|
||||
for attr_name, value in (
|
||||
('manufacturer', manufacturer),
|
||||
('model', model),
|
||||
('name', name),
|
||||
('sw_version', sw_version),
|
||||
('hub_device_id', hub_device_id),
|
||||
):
|
||||
if value is not _UNDEF and value != getattr(old, attr_name):
|
||||
changes[attr_name] = value
|
||||
|
||||
if not changes:
|
||||
return old
|
||||
|
|
|
@ -27,7 +27,6 @@ async def test_list_devices(hass, client, registry):
|
|||
manufacturer='manufacturer', model='model')
|
||||
registry.async_get_or_create(
|
||||
config_entry_id='1234',
|
||||
connections={},
|
||||
identifiers={('bridgeid', '1234')},
|
||||
manufacturer='manufacturer', model='model',
|
||||
via_hub=('bridgeid', '0123'))
|
||||
|
|
|
@ -17,7 +17,10 @@ async def test_get_or_create_returns_same_entry(registry):
|
|||
config_entry_id='1234',
|
||||
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
|
||||
identifiers={('bridgeid', '0123')},
|
||||
manufacturer='manufacturer', model='model')
|
||||
sw_version='sw-version',
|
||||
name='name',
|
||||
manufacturer='manufacturer',
|
||||
model='model')
|
||||
entry2 = registry.async_get_or_create(
|
||||
config_entry_id='1234',
|
||||
connections={('ethernet', '11:22:33:44:55:66:77:88')},
|
||||
|
@ -25,15 +28,19 @@ async def test_get_or_create_returns_same_entry(registry):
|
|||
manufacturer='manufacturer', model='model')
|
||||
entry3 = registry.async_get_or_create(
|
||||
config_entry_id='1234',
|
||||
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
|
||||
identifiers={('bridgeid', '1234')},
|
||||
manufacturer='manufacturer', model='model')
|
||||
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}
|
||||
)
|
||||
|
||||
assert len(registry.devices) == 1
|
||||
assert entry.id == entry2.id
|
||||
assert entry.id == entry3.id
|
||||
assert entry.identifiers == {('bridgeid', '0123')}
|
||||
|
||||
assert entry3.manufacturer == 'manufacturer'
|
||||
assert entry3.model == 'model'
|
||||
assert entry3.name == 'name'
|
||||
assert entry3.sw_version == 'sw-version'
|
||||
|
||||
|
||||
async def test_requirement_for_identifier_or_connection(registry):
|
||||
"""Make sure we do require some descriptor of device."""
|
||||
|
|
Loading…
Reference in New Issue