From 65e56d03bf4ebf3adc6329cd10e6fe48db83e426 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Skytt=C3=A4?= Date: Tue, 5 Jan 2021 03:03:16 +0200 Subject: [PATCH] Complete device and entity registry type hints (#44406) --- homeassistant/helpers/device_registry.py | 184 ++++++++++++----------- homeassistant/helpers/entity_registry.py | 85 +++++------ 2 files changed, 133 insertions(+), 136 deletions(-) diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index 6e8c09bbd60..a115434fad9 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -11,13 +11,11 @@ import homeassistant.util.uuid as uuid_util from .debounce import Debouncer from .singleton import singleton -from .typing import UNDEFINED, HomeAssistantType +from .typing import UNDEFINED, HomeAssistantType, UndefinedType if TYPE_CHECKING: from . import entity_registry -# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs - _LOGGER = logging.getLogger(__name__) DATA_REGISTRY = "device_registry" @@ -40,26 +38,6 @@ DISABLED_INTEGRATION = "integration" DISABLED_USER = "user" -@attr.s(slots=True, frozen=True) -class DeletedDeviceEntry: - """Deleted Device Registry Entry.""" - - config_entries: Set[str] = attr.ib() - connections: Set[Tuple[str, str]] = attr.ib() - identifiers: Set[Tuple[str, str]] = attr.ib() - id: str = attr.ib() - - def to_device_entry(self, config_entry_id, connections, identifiers): - """Create DeviceEntry from DeletedDeviceEntry.""" - return DeviceEntry( - config_entries={config_entry_id}, - connections=self.connections & connections, - identifiers=self.identifiers & identifiers, - id=self.id, - is_new=True, - ) - - @attr.s(slots=True, frozen=True) class DeviceEntry: """Device Registry Entry.""" @@ -67,14 +45,14 @@ class DeviceEntry: config_entries: Set[str] = attr.ib(converter=set, factory=set) connections: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set) identifiers: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set) - manufacturer: str = attr.ib(default=None) - model: str = attr.ib(default=None) - name: str = attr.ib(default=None) - sw_version: str = attr.ib(default=None) - via_device_id: str = attr.ib(default=None) - area_id: str = attr.ib(default=None) - name_by_user: str = attr.ib(default=None) - entry_type: str = attr.ib(default=None) + manufacturer: Optional[str] = attr.ib(default=None) + model: Optional[str] = attr.ib(default=None) + name: Optional[str] = attr.ib(default=None) + sw_version: Optional[str] = attr.ib(default=None) + via_device_id: Optional[str] = attr.ib(default=None) + area_id: Optional[str] = attr.ib(default=None) + name_by_user: Optional[str] = attr.ib(default=None) + entry_type: Optional[str] = attr.ib(default=None) id: str = attr.ib(factory=uuid_util.random_uuid_hex) # This value is not stored, just used to keep track of events to fire. is_new: bool = attr.ib(default=False) @@ -95,6 +73,32 @@ class DeviceEntry: return self.disabled_by is not None +@attr.s(slots=True, frozen=True) +class DeletedDeviceEntry: + """Deleted Device Registry Entry.""" + + config_entries: Set[str] = attr.ib() + connections: Set[Tuple[str, str]] = attr.ib() + identifiers: Set[Tuple[str, str]] = attr.ib() + id: str = attr.ib() + + def to_device_entry( + self, + config_entry_id: str, + connections: Set[Tuple[str, str]], + identifiers: Set[Tuple[str, str]], + ) -> DeviceEntry: + """Create DeviceEntry from DeletedDeviceEntry.""" + return DeviceEntry( + # type ignores: likely https://github.com/python/mypy/issues/8625 + config_entries={config_entry_id}, # type: ignore[arg-type] + connections=self.connections & connections, # type: ignore[arg-type] + identifiers=self.identifiers & identifiers, # type: ignore[arg-type] + id=self.id, + is_new=True, + ) + + def format_mac(mac: str) -> str: """Format the mac address string for entry into dev reg.""" to_test = mac @@ -201,40 +205,40 @@ class DeviceRegistry: _remove_device_from_index(devices_index, old_device) _add_device_to_index(devices_index, new_device) - def _clear_index(self): + def _clear_index(self) -> None: """Clear the index.""" self._devices_index = { REGISTERED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}}, DELETED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}}, } - def _rebuild_index(self): + def _rebuild_index(self) -> None: """Create the index after loading devices.""" self._clear_index() for device in self.devices.values(): _add_device_to_index(self._devices_index[REGISTERED_DEVICE], device) - for device in self.deleted_devices.values(): - _add_device_to_index(self._devices_index[DELETED_DEVICE], device) + for deleted_device in self.deleted_devices.values(): + _add_device_to_index(self._devices_index[DELETED_DEVICE], deleted_device) @callback def async_get_or_create( self, *, - config_entry_id, - connections=None, - identifiers=None, - manufacturer=UNDEFINED, - model=UNDEFINED, - name=UNDEFINED, - default_manufacturer=UNDEFINED, - default_model=UNDEFINED, - default_name=UNDEFINED, - sw_version=UNDEFINED, - entry_type=UNDEFINED, - via_device=None, + config_entry_id: str, + connections: Optional[set] = None, + identifiers: Optional[set] = None, + manufacturer: Union[str, None, UndefinedType] = UNDEFINED, + model: Union[str, None, UndefinedType] = UNDEFINED, + name: Union[str, None, UndefinedType] = UNDEFINED, + default_manufacturer: Union[str, None, UndefinedType] = UNDEFINED, + default_model: Union[str, None, UndefinedType] = UNDEFINED, + default_name: Union[str, None, UndefinedType] = UNDEFINED, + sw_version: Union[str, None, UndefinedType] = UNDEFINED, + entry_type: Union[str, None, UndefinedType] = UNDEFINED, + via_device: Optional[str] = None, # To disable a device if it gets created - disabled_by=UNDEFINED, - ): + disabled_by: Union[str, None, UndefinedType] = UNDEFINED, + ) -> Optional[DeviceEntry]: """Get device. Create if it doesn't exist.""" if not identifiers and not connections: return None @@ -271,7 +275,7 @@ class DeviceRegistry: if via_device is not None: via = self.async_get_device({via_device}, set()) - via_device_id = via.id if via else UNDEFINED + via_device_id: Union[str, UndefinedType] = via.id if via else UNDEFINED else: via_device_id = UNDEFINED @@ -292,19 +296,19 @@ class DeviceRegistry: @callback def async_update_device( self, - device_id, + device_id: str, *, - area_id=UNDEFINED, - manufacturer=UNDEFINED, - model=UNDEFINED, - name=UNDEFINED, - name_by_user=UNDEFINED, - new_identifiers=UNDEFINED, - sw_version=UNDEFINED, - via_device_id=UNDEFINED, - remove_config_entry_id=UNDEFINED, - disabled_by=UNDEFINED, - ): + area_id: Union[str, None, UndefinedType] = UNDEFINED, + manufacturer: Union[str, None, UndefinedType] = UNDEFINED, + model: Union[str, None, UndefinedType] = UNDEFINED, + name: Union[str, None, UndefinedType] = UNDEFINED, + name_by_user: Union[str, None, UndefinedType] = UNDEFINED, + new_identifiers: Union[set, UndefinedType] = UNDEFINED, + sw_version: Union[str, None, UndefinedType] = UNDEFINED, + via_device_id: Union[str, None, UndefinedType] = UNDEFINED, + remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED, + disabled_by: Union[str, None, UndefinedType] = UNDEFINED, + ) -> Optional[DeviceEntry]: """Update properties of a device.""" return self._async_update_device( device_id, @@ -323,27 +327,27 @@ class DeviceRegistry: @callback def _async_update_device( self, - device_id, + device_id: str, *, - add_config_entry_id=UNDEFINED, - remove_config_entry_id=UNDEFINED, - merge_connections=UNDEFINED, - merge_identifiers=UNDEFINED, - new_identifiers=UNDEFINED, - manufacturer=UNDEFINED, - model=UNDEFINED, - name=UNDEFINED, - sw_version=UNDEFINED, - entry_type=UNDEFINED, - via_device_id=UNDEFINED, - area_id=UNDEFINED, - name_by_user=UNDEFINED, - disabled_by=UNDEFINED, - ): + add_config_entry_id: Union[str, UndefinedType] = UNDEFINED, + remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED, + merge_connections: Union[set, UndefinedType] = UNDEFINED, + merge_identifiers: Union[set, UndefinedType] = UNDEFINED, + new_identifiers: Union[set, UndefinedType] = UNDEFINED, + manufacturer: Union[str, None, UndefinedType] = UNDEFINED, + model: Union[str, None, UndefinedType] = UNDEFINED, + name: Union[str, None, UndefinedType] = UNDEFINED, + sw_version: Union[str, None, UndefinedType] = UNDEFINED, + entry_type: Union[str, None, UndefinedType] = UNDEFINED, + via_device_id: Union[str, None, UndefinedType] = UNDEFINED, + area_id: Union[str, None, UndefinedType] = UNDEFINED, + name_by_user: Union[str, None, UndefinedType] = UNDEFINED, + disabled_by: Union[str, None, UndefinedType] = UNDEFINED, + ) -> Optional[DeviceEntry]: """Update device attributes.""" old = self.devices[device_id] - changes = {} + changes: Dict[str, Any] = {} config_entries = old.config_entries @@ -359,21 +363,21 @@ class DeviceRegistry: ): if config_entries == {remove_config_entry_id}: self.async_remove_device(device_id) - return + return None config_entries = config_entries - {remove_config_entry_id} if config_entries != old.config_entries: changes["config_entries"] = config_entries - for attr_name, value in ( + for attr_name, setvalue in ( ("connections", merge_connections), ("identifiers", merge_identifiers), ): old_value = getattr(old, attr_name) # If not undefined, check if `value` contains new items. - if value is not UNDEFINED and not value.issubset(old_value): - changes[attr_name] = old_value | value + if setvalue is not UNDEFINED and not setvalue.issubset(old_value): + changes[attr_name] = old_value | setvalue if new_identifiers is not UNDEFINED: changes["identifiers"] = new_identifiers @@ -434,7 +438,7 @@ class DeviceRegistry: ) self.async_schedule_save() - async def async_load(self): + async def async_load(self) -> None: """Load the device registry.""" async_setup_cleanup(self.hass, self) @@ -447,8 +451,9 @@ class DeviceRegistry: for device in data["devices"]: devices[device["id"]] = DeviceEntry( config_entries=set(device["config_entries"]), - connections={tuple(conn) for conn in device["connections"]}, - identifiers={tuple(iden) for iden in device["identifiers"]}, + # type ignores (if tuple arg was cast): likely https://github.com/python/mypy/issues/8625 + connections={tuple(conn) for conn in device["connections"]}, # type: ignore[misc] + identifiers={tuple(iden) for iden in device["identifiers"]}, # type: ignore[misc] manufacturer=device["manufacturer"], model=device["model"], name=device["name"], @@ -471,8 +476,9 @@ class DeviceRegistry: for device in data.get("deleted_devices", []): deleted_devices[device["id"]] = DeletedDeviceEntry( config_entries=set(device["config_entries"]), - connections={tuple(conn) for conn in device["connections"]}, - identifiers={tuple(iden) for iden in device["identifiers"]}, + # type ignores (if tuple arg was cast): likely https://github.com/python/mypy/issues/8625 + connections={tuple(conn) for conn in device["connections"]}, # type: ignore[misc] + identifiers={tuple(iden) for iden in device["identifiers"]}, # type: ignore[misc] id=device["id"], ) @@ -614,7 +620,7 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non """Clean up device registry when entities removed.""" from . import entity_registry # pylint: disable=import-outside-toplevel - async def cleanup(): + async def cleanup() -> None: """Cleanup.""" ent_reg = await entity_registry.async_get_registry(hass) async_cleanup(hass, dev_reg, ent_reg) diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 44f5c9c56f7..95497f7179c 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -18,7 +18,7 @@ from typing import ( List, Optional, Tuple, - cast, + Union, ) import attr @@ -39,13 +39,11 @@ from homeassistant.util import slugify from homeassistant.util.yaml import load_yaml from .singleton import singleton -from .typing import UNDEFINED, HomeAssistantType +from .typing import UNDEFINED, HomeAssistantType, UndefinedType if TYPE_CHECKING: from homeassistant.config_entries import ConfigEntry # noqa: F401 -# mypy: allow-untyped-defs, no-check-untyped-defs - PATH_REGISTRY = "entity_registry.yaml" DATA_REGISTRY = "entity_registry" EVENT_ENTITY_REGISTRY_UPDATED = "entity_registry_updated" @@ -222,7 +220,7 @@ class EntityRegistry: entity_id = self.async_get_entity_id(domain, platform, unique_id) if entity_id: - return self._async_update_entity( # type: ignore + return self._async_update_entity( entity_id, config_entry_id=config_entry_id or UNDEFINED, device_id=device_id or UNDEFINED, @@ -316,63 +314,56 @@ class EntityRegistry: for entity in entities: if entity.disabled_by != DISABLED_DEVICE: continue - self.async_update_entity( # type: ignore - entity.entity_id, disabled_by=None - ) + self.async_update_entity(entity.entity_id, disabled_by=None) return entities = async_entries_for_device(self, event.data["device_id"]) for entity in entities: - self.async_update_entity( # type: ignore - entity.entity_id, disabled_by=DISABLED_DEVICE - ) + self.async_update_entity(entity.entity_id, disabled_by=DISABLED_DEVICE) @callback def async_update_entity( self, - entity_id, + entity_id: str, *, - name=UNDEFINED, - icon=UNDEFINED, - area_id=UNDEFINED, - new_entity_id=UNDEFINED, - new_unique_id=UNDEFINED, - disabled_by=UNDEFINED, - ): + name: Union[str, None, UndefinedType] = UNDEFINED, + icon: Union[str, None, UndefinedType] = UNDEFINED, + area_id: Union[str, None, UndefinedType] = UNDEFINED, + new_entity_id: Union[str, UndefinedType] = UNDEFINED, + new_unique_id: Union[str, UndefinedType] = UNDEFINED, + disabled_by: Union[str, None, UndefinedType] = UNDEFINED, + ) -> RegistryEntry: """Update properties of an entity.""" - return cast( # cast until we have _async_update_entity type hinted - RegistryEntry, - self._async_update_entity( - entity_id, - name=name, - icon=icon, - area_id=area_id, - new_entity_id=new_entity_id, - new_unique_id=new_unique_id, - disabled_by=disabled_by, - ), + return self._async_update_entity( + entity_id, + name=name, + icon=icon, + area_id=area_id, + new_entity_id=new_entity_id, + new_unique_id=new_unique_id, + disabled_by=disabled_by, ) @callback def _async_update_entity( self, - entity_id, + entity_id: str, *, - name=UNDEFINED, - icon=UNDEFINED, - config_entry_id=UNDEFINED, - new_entity_id=UNDEFINED, - device_id=UNDEFINED, - area_id=UNDEFINED, - new_unique_id=UNDEFINED, - disabled_by=UNDEFINED, - capabilities=UNDEFINED, - supported_features=UNDEFINED, - device_class=UNDEFINED, - unit_of_measurement=UNDEFINED, - original_name=UNDEFINED, - original_icon=UNDEFINED, - ): + name: Union[str, None, UndefinedType] = UNDEFINED, + icon: Union[str, None, UndefinedType] = UNDEFINED, + config_entry_id: Union[str, None, UndefinedType] = UNDEFINED, + new_entity_id: Union[str, UndefinedType] = UNDEFINED, + device_id: Union[str, None, UndefinedType] = UNDEFINED, + area_id: Union[str, None, UndefinedType] = UNDEFINED, + new_unique_id: Union[str, UndefinedType] = UNDEFINED, + disabled_by: Union[str, None, UndefinedType] = UNDEFINED, + capabilities: Union[Dict[str, Any], None, UndefinedType] = UNDEFINED, + supported_features: Union[int, UndefinedType] = UNDEFINED, + device_class: Union[str, None, UndefinedType] = UNDEFINED, + unit_of_measurement: Union[str, None, UndefinedType] = UNDEFINED, + original_name: Union[str, None, UndefinedType] = UNDEFINED, + original_icon: Union[str, None, UndefinedType] = UNDEFINED, + ) -> RegistryEntry: """Private facing update properties method.""" old = self.entities[entity_id] @@ -526,7 +517,7 @@ class EntityRegistry: """Clear area id from registry entries.""" for entity_id, entry in self.entities.items(): if area_id == entry.area_id: - self._async_update_entity(entity_id, area_id=None) # type: ignore + self._async_update_entity(entity_id, area_id=None) def _register_entry(self, entry: RegistryEntry) -> None: self.entities[entry.entity_id] = entry