Add entity options to entity registry (#64350)

* Initial commit for entity options

* Tweak broadlink tests

* Add async_update_entity_options + test
pull/64437/head
Erik Montnemery 2022-01-18 22:47:46 +01:00 committed by GitHub
parent 57bcddbba2
commit a8c14835b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 11 deletions

View File

@ -57,7 +57,7 @@ SAVE_DELAY = 10
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
STORAGE_VERSION_MAJOR = 1 STORAGE_VERSION_MAJOR = 1
STORAGE_VERSION_MINOR = 4 STORAGE_VERSION_MINOR = 5
STORAGE_KEY = "core.entity_registry" STORAGE_KEY = "core.entity_registry"
# Attributes relevant to describing entity # Attributes relevant to describing entity
@ -109,6 +109,9 @@ class RegistryEntry:
icon: str | None = attr.ib(default=None) icon: str | None = attr.ib(default=None)
id: str = attr.ib(factory=uuid_util.random_uuid_hex) id: str = attr.ib(factory=uuid_util.random_uuid_hex)
name: str | None = attr.ib(default=None) name: str | None = attr.ib(default=None)
options: Mapping[str, Mapping[str, Any]] = attr.ib(
default=None, converter=attr.converters.default_if_none(factory=dict) # type: ignore[misc]
)
# As set by integration # As set by integration
original_device_class: str | None = attr.ib(default=None) original_device_class: str | None = attr.ib(default=None)
original_icon: str | None = attr.ib(default=None) original_icon: str | None = attr.ib(default=None)
@ -560,6 +563,25 @@ class EntityRegistry:
return new return new
@callback
def async_update_entity_options(
self, entity_id: str, domain: str, options: dict[str, Any]
) -> None:
"""Update entity options."""
old = self.entities[entity_id]
new_options: Mapping[str, Mapping[str, Any]] = {**old.options, domain: options}
self.entities[entity_id] = attr.evolve(old, options=new_options)
self.async_schedule_save()
data: dict[str, str | dict[str, Any]] = {
"action": "update",
"entity_id": entity_id,
"changes": {"options": old.options},
}
self.hass.bus.async_fire(EVENT_ENTITY_REGISTRY_UPDATED, data)
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load the entity registry.""" """Load the entity registry."""
async_setup_entity_restore(self.hass, self) async_setup_entity_restore(self.hass, self)
@ -595,6 +617,7 @@ class EntityRegistry:
icon=entity["icon"], icon=entity["icon"],
id=entity["id"], id=entity["id"],
name=entity["name"], name=entity["name"],
options=entity["options"],
original_device_class=entity["original_device_class"], original_device_class=entity["original_device_class"],
original_icon=entity["original_icon"], original_icon=entity["original_icon"],
original_name=entity["original_name"], original_name=entity["original_name"],
@ -629,6 +652,7 @@ class EntityRegistry:
"icon": entry.icon, "icon": entry.icon,
"id": entry.id, "id": entry.id,
"name": entry.name, "name": entry.name,
"options": entry.options,
"original_device_class": entry.original_device_class, "original_device_class": entry.original_device_class,
"original_icon": entry.original_icon, "original_icon": entry.original_icon,
"original_name": entry.original_name, "original_name": entry.original_name,
@ -749,7 +773,7 @@ async def _async_migrate(
old_major_version: int, old_minor_version: int, data: dict old_major_version: int, old_minor_version: int, data: dict
) -> dict: ) -> dict:
"""Migrate to the new version.""" """Migrate to the new version."""
if old_major_version < 2 and old_minor_version < 2: if old_major_version == 1 and old_minor_version < 2:
# From version 1.1 # From version 1.1
for entity in data["entities"]: for entity in data["entities"]:
# Populate all keys # Populate all keys
@ -768,18 +792,23 @@ async def _async_migrate(
entity["supported_features"] = entity.get("supported_features", 0) entity["supported_features"] = entity.get("supported_features", 0)
entity["unit_of_measurement"] = entity.get("unit_of_measurement") entity["unit_of_measurement"] = entity.get("unit_of_measurement")
if old_major_version < 2 and old_minor_version < 3: if old_major_version == 1 and old_minor_version < 3:
# Version 1.3 adds original_device_class # Version 1.3 adds original_device_class
for entity in data["entities"]: for entity in data["entities"]:
# Move device_class to original_device_class # Move device_class to original_device_class
entity["original_device_class"] = entity["device_class"] entity["original_device_class"] = entity["device_class"]
entity["device_class"] = None entity["device_class"] = None
if old_major_version < 2 and old_minor_version < 4: if old_major_version == 1 and old_minor_version < 4:
# Version 1.4 adds id # Version 1.4 adds id
for entity in data["entities"]: for entity in data["entities"]:
entity["id"] = uuid_util.random_uuid_hex() entity["id"] = uuid_util.random_uuid_hex()
if old_major_version == 1 and old_minor_version < 5:
# Version 1.5 adds entity options
for entity in data["entities"]:
entity["options"] = {}
if old_major_version > 1: if old_major_version > 1:
raise NotImplementedError raise NotImplementedError
return data return data

View File

@ -34,10 +34,10 @@ async def test_remote_setup_works(hass):
{(DOMAIN, mock_setup.entry.unique_id)} {(DOMAIN, mock_setup.entry.unique_id)}
) )
entries = async_entries_for_device(entity_registry, device_entry.id) entries = async_entries_for_device(entity_registry, device_entry.id)
remotes = {entry for entry in entries if entry.domain == Platform.REMOTE} remotes = [entry for entry in entries if entry.domain == Platform.REMOTE]
assert len(remotes) == 1 assert len(remotes) == 1
remote = remotes.pop() remote = remotes[0]
assert remote.original_name == f"{device.name} Remote" assert remote.original_name == f"{device.name} Remote"
assert hass.states.get(remote.entity_id).state == STATE_ON assert hass.states.get(remote.entity_id).state == STATE_ON
assert mock_setup.api.auth.call_count == 1 assert mock_setup.api.auth.call_count == 1
@ -54,10 +54,10 @@ async def test_remote_send_command(hass):
{(DOMAIN, mock_setup.entry.unique_id)} {(DOMAIN, mock_setup.entry.unique_id)}
) )
entries = async_entries_for_device(entity_registry, device_entry.id) entries = async_entries_for_device(entity_registry, device_entry.id)
remotes = {entry for entry in entries if entry.domain == Platform.REMOTE} remotes = [entry for entry in entries if entry.domain == Platform.REMOTE]
assert len(remotes) == 1 assert len(remotes) == 1
remote = remotes.pop() remote = remotes[0]
await hass.services.async_call( await hass.services.async_call(
Platform.REMOTE, Platform.REMOTE,
SERVICE_SEND_COMMAND, SERVICE_SEND_COMMAND,
@ -81,10 +81,10 @@ async def test_remote_turn_off_turn_on(hass):
{(DOMAIN, mock_setup.entry.unique_id)} {(DOMAIN, mock_setup.entry.unique_id)}
) )
entries = async_entries_for_device(entity_registry, device_entry.id) entries = async_entries_for_device(entity_registry, device_entry.id)
remotes = {entry for entry in entries if entry.domain == Platform.REMOTE} remotes = [entry for entry in entries if entry.domain == Platform.REMOTE]
assert len(remotes) == 1 assert len(remotes) == 1
remote = remotes.pop() remote = remotes[0]
await hass.services.async_call( await hass.services.async_call(
Platform.REMOTE, Platform.REMOTE,
SERVICE_TURN_OFF, SERVICE_TURN_OFF,

View File

@ -196,12 +196,16 @@ async def test_loading_saving_data(hass, registry):
supported_features=5, supported_features=5,
unit_of_measurement="initial-unit_of_measurement", unit_of_measurement="initial-unit_of_measurement",
) )
orig_entry2 = registry.async_update_entity( registry.async_update_entity(
orig_entry2.entity_id, orig_entry2.entity_id,
device_class="user-class", device_class="user-class",
name="User Name", name="User Name",
icon="hass:user-icon", icon="hass:user-icon",
) )
registry.async_update_entity_options(
orig_entry2.entity_id, "light", {"minimum_brightness": 20}
)
orig_entry2 = registry.async_get(orig_entry2.entity_id)
assert len(registry.entities) == 2 assert len(registry.entities) == 2
@ -227,6 +231,7 @@ async def test_loading_saving_data(hass, registry):
assert new_entry2.entity_category == "config" assert new_entry2.entity_category == "config"
assert new_entry2.icon == "hass:user-icon" assert new_entry2.icon == "hass:user-icon"
assert new_entry2.name == "User Name" assert new_entry2.name == "User Name"
assert new_entry2.options == {"light": {"minimum_brightness": 20}}
assert new_entry2.original_device_class == "mock-device-class" assert new_entry2.original_device_class == "mock-device-class"
assert new_entry2.original_icon == "hass:original-icon" assert new_entry2.original_icon == "hass:original-icon"
assert new_entry2.original_name == "Original Name" assert new_entry2.original_name == "Original Name"
@ -570,6 +575,31 @@ async def test_update_entity(registry):
entry = updated_entry entry = updated_entry
async def test_update_entity_options(registry):
"""Test updating entity."""
mock_config = MockConfigEntry(domain="light", entry_id="mock-id-1")
entry = registry.async_get_or_create(
"light", "hue", "5678", config_entry=mock_config
)
registry.async_update_entity_options(
entry.entity_id, "light", {"minimum_brightness": 20}
)
new_entry_1 = registry.async_get(entry.entity_id)
assert entry.options == {}
assert new_entry_1.options == {"light": {"minimum_brightness": 20}}
registry.async_update_entity_options(
entry.entity_id, "light", {"minimum_brightness": 30}
)
new_entry_2 = registry.async_get(entry.entity_id)
assert entry.options == {}
assert new_entry_1.options == {"light": {"minimum_brightness": 20}}
assert new_entry_2.options == {"light": {"minimum_brightness": 30}}
async def test_disabled_by(registry): async def test_disabled_by(registry):
"""Test that we can disable an entry when we create it.""" """Test that we can disable an entry when we create it."""
entry = registry.async_get_or_create( entry = registry.async_get_or_create(