Complete device and entity registry type hints (#44406)

pull/44835/head
Ville Skyttä 2021-01-05 03:03:16 +02:00 committed by GitHub
parent d315ab2cf5
commit 65e56d03bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 133 additions and 136 deletions

View File

@ -11,13 +11,11 @@ import homeassistant.util.uuid as uuid_util
from .debounce import Debouncer
from .singleton import singleton
from .typing import UNDEFINED, HomeAssistantType
from .typing import UNDEFINED, HomeAssistantType, UndefinedType
if TYPE_CHECKING:
from . import entity_registry
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
DATA_REGISTRY = "device_registry"
@ -40,26 +38,6 @@ DISABLED_INTEGRATION = "integration"
DISABLED_USER = "user"
@attr.s(slots=True, frozen=True)
class DeletedDeviceEntry:
"""Deleted Device Registry Entry."""
config_entries: Set[str] = attr.ib()
connections: Set[Tuple[str, str]] = attr.ib()
identifiers: Set[Tuple[str, str]] = attr.ib()
id: str = attr.ib()
def to_device_entry(self, config_entry_id, connections, identifiers):
"""Create DeviceEntry from DeletedDeviceEntry."""
return DeviceEntry(
config_entries={config_entry_id},
connections=self.connections & connections,
identifiers=self.identifiers & identifiers,
id=self.id,
is_new=True,
)
@attr.s(slots=True, frozen=True)
class DeviceEntry:
"""Device Registry Entry."""
@ -67,14 +45,14 @@ class DeviceEntry:
config_entries: Set[str] = attr.ib(converter=set, factory=set)
connections: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set)
identifiers: Set[Tuple[str, str]] = attr.ib(converter=set, factory=set)
manufacturer: str = attr.ib(default=None)
model: str = attr.ib(default=None)
name: str = attr.ib(default=None)
sw_version: str = attr.ib(default=None)
via_device_id: str = attr.ib(default=None)
area_id: str = attr.ib(default=None)
name_by_user: str = attr.ib(default=None)
entry_type: str = attr.ib(default=None)
manufacturer: Optional[str] = attr.ib(default=None)
model: Optional[str] = attr.ib(default=None)
name: Optional[str] = attr.ib(default=None)
sw_version: Optional[str] = attr.ib(default=None)
via_device_id: Optional[str] = attr.ib(default=None)
area_id: Optional[str] = attr.ib(default=None)
name_by_user: Optional[str] = attr.ib(default=None)
entry_type: Optional[str] = attr.ib(default=None)
id: str = attr.ib(factory=uuid_util.random_uuid_hex)
# This value is not stored, just used to keep track of events to fire.
is_new: bool = attr.ib(default=False)
@ -95,6 +73,32 @@ class DeviceEntry:
return self.disabled_by is not None
@attr.s(slots=True, frozen=True)
class DeletedDeviceEntry:
"""Deleted Device Registry Entry."""
config_entries: Set[str] = attr.ib()
connections: Set[Tuple[str, str]] = attr.ib()
identifiers: Set[Tuple[str, str]] = attr.ib()
id: str = attr.ib()
def to_device_entry(
self,
config_entry_id: str,
connections: Set[Tuple[str, str]],
identifiers: Set[Tuple[str, str]],
) -> DeviceEntry:
"""Create DeviceEntry from DeletedDeviceEntry."""
return DeviceEntry(
# type ignores: likely https://github.com/python/mypy/issues/8625
config_entries={config_entry_id}, # type: ignore[arg-type]
connections=self.connections & connections, # type: ignore[arg-type]
identifiers=self.identifiers & identifiers, # type: ignore[arg-type]
id=self.id,
is_new=True,
)
def format_mac(mac: str) -> str:
"""Format the mac address string for entry into dev reg."""
to_test = mac
@ -201,40 +205,40 @@ class DeviceRegistry:
_remove_device_from_index(devices_index, old_device)
_add_device_to_index(devices_index, new_device)
def _clear_index(self):
def _clear_index(self) -> None:
"""Clear the index."""
self._devices_index = {
REGISTERED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}},
DELETED_DEVICE: {IDX_IDENTIFIERS: {}, IDX_CONNECTIONS: {}},
}
def _rebuild_index(self):
def _rebuild_index(self) -> None:
"""Create the index after loading devices."""
self._clear_index()
for device in self.devices.values():
_add_device_to_index(self._devices_index[REGISTERED_DEVICE], device)
for device in self.deleted_devices.values():
_add_device_to_index(self._devices_index[DELETED_DEVICE], device)
for deleted_device in self.deleted_devices.values():
_add_device_to_index(self._devices_index[DELETED_DEVICE], deleted_device)
@callback
def async_get_or_create(
self,
*,
config_entry_id,
connections=None,
identifiers=None,
manufacturer=UNDEFINED,
model=UNDEFINED,
name=UNDEFINED,
default_manufacturer=UNDEFINED,
default_model=UNDEFINED,
default_name=UNDEFINED,
sw_version=UNDEFINED,
entry_type=UNDEFINED,
via_device=None,
config_entry_id: str,
connections: Optional[set] = None,
identifiers: Optional[set] = None,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED,
default_manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
default_model: Union[str, None, UndefinedType] = UNDEFINED,
default_name: Union[str, None, UndefinedType] = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED,
entry_type: Union[str, None, UndefinedType] = UNDEFINED,
via_device: Optional[str] = None,
# To disable a device if it gets created
disabled_by=UNDEFINED,
):
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
) -> Optional[DeviceEntry]:
"""Get device. Create if it doesn't exist."""
if not identifiers and not connections:
return None
@ -271,7 +275,7 @@ class DeviceRegistry:
if via_device is not None:
via = self.async_get_device({via_device}, set())
via_device_id = via.id if via else UNDEFINED
via_device_id: Union[str, UndefinedType] = via.id if via else UNDEFINED
else:
via_device_id = UNDEFINED
@ -292,19 +296,19 @@ class DeviceRegistry:
@callback
def async_update_device(
self,
device_id,
device_id: str,
*,
area_id=UNDEFINED,
manufacturer=UNDEFINED,
model=UNDEFINED,
name=UNDEFINED,
name_by_user=UNDEFINED,
new_identifiers=UNDEFINED,
sw_version=UNDEFINED,
via_device_id=UNDEFINED,
remove_config_entry_id=UNDEFINED,
disabled_by=UNDEFINED,
):
area_id: Union[str, None, UndefinedType] = UNDEFINED,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED,
name_by_user: Union[str, None, UndefinedType] = UNDEFINED,
new_identifiers: Union[set, UndefinedType] = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED,
via_device_id: Union[str, None, UndefinedType] = UNDEFINED,
remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
) -> Optional[DeviceEntry]:
"""Update properties of a device."""
return self._async_update_device(
device_id,
@ -323,27 +327,27 @@ class DeviceRegistry:
@callback
def _async_update_device(
self,
device_id,
device_id: str,
*,
add_config_entry_id=UNDEFINED,
remove_config_entry_id=UNDEFINED,
merge_connections=UNDEFINED,
merge_identifiers=UNDEFINED,
new_identifiers=UNDEFINED,
manufacturer=UNDEFINED,
model=UNDEFINED,
name=UNDEFINED,
sw_version=UNDEFINED,
entry_type=UNDEFINED,
via_device_id=UNDEFINED,
area_id=UNDEFINED,
name_by_user=UNDEFINED,
disabled_by=UNDEFINED,
):
add_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
remove_config_entry_id: Union[str, UndefinedType] = UNDEFINED,
merge_connections: Union[set, UndefinedType] = UNDEFINED,
merge_identifiers: Union[set, UndefinedType] = UNDEFINED,
new_identifiers: Union[set, UndefinedType] = UNDEFINED,
manufacturer: Union[str, None, UndefinedType] = UNDEFINED,
model: Union[str, None, UndefinedType] = UNDEFINED,
name: Union[str, None, UndefinedType] = UNDEFINED,
sw_version: Union[str, None, UndefinedType] = UNDEFINED,
entry_type: Union[str, None, UndefinedType] = UNDEFINED,
via_device_id: Union[str, None, UndefinedType] = UNDEFINED,
area_id: Union[str, None, UndefinedType] = UNDEFINED,
name_by_user: Union[str, None, UndefinedType] = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
) -> Optional[DeviceEntry]:
"""Update device attributes."""
old = self.devices[device_id]
changes = {}
changes: Dict[str, Any] = {}
config_entries = old.config_entries
@ -359,21 +363,21 @@ class DeviceRegistry:
):
if config_entries == {remove_config_entry_id}:
self.async_remove_device(device_id)
return
return None
config_entries = config_entries - {remove_config_entry_id}
if config_entries != old.config_entries:
changes["config_entries"] = config_entries
for attr_name, value in (
for attr_name, setvalue in (
("connections", merge_connections),
("identifiers", merge_identifiers),
):
old_value = getattr(old, attr_name)
# If not undefined, check if `value` contains new items.
if value is not UNDEFINED and not value.issubset(old_value):
changes[attr_name] = old_value | value
if setvalue is not UNDEFINED and not setvalue.issubset(old_value):
changes[attr_name] = old_value | setvalue
if new_identifiers is not UNDEFINED:
changes["identifiers"] = new_identifiers
@ -434,7 +438,7 @@ class DeviceRegistry:
)
self.async_schedule_save()
async def async_load(self):
async def async_load(self) -> None:
"""Load the device registry."""
async_setup_cleanup(self.hass, self)
@ -447,8 +451,9 @@ class DeviceRegistry:
for device in data["devices"]:
devices[device["id"]] = DeviceEntry(
config_entries=set(device["config_entries"]),
connections={tuple(conn) for conn in device["connections"]},
identifiers={tuple(iden) for iden in device["identifiers"]},
# type ignores (if tuple arg was cast): likely https://github.com/python/mypy/issues/8625
connections={tuple(conn) for conn in device["connections"]}, # type: ignore[misc]
identifiers={tuple(iden) for iden in device["identifiers"]}, # type: ignore[misc]
manufacturer=device["manufacturer"],
model=device["model"],
name=device["name"],
@ -471,8 +476,9 @@ class DeviceRegistry:
for device in data.get("deleted_devices", []):
deleted_devices[device["id"]] = DeletedDeviceEntry(
config_entries=set(device["config_entries"]),
connections={tuple(conn) for conn in device["connections"]},
identifiers={tuple(iden) for iden in device["identifiers"]},
# type ignores (if tuple arg was cast): likely https://github.com/python/mypy/issues/8625
connections={tuple(conn) for conn in device["connections"]}, # type: ignore[misc]
identifiers={tuple(iden) for iden in device["identifiers"]}, # type: ignore[misc]
id=device["id"],
)
@ -614,7 +620,7 @@ def async_setup_cleanup(hass: HomeAssistantType, dev_reg: DeviceRegistry) -> Non
"""Clean up device registry when entities removed."""
from . import entity_registry # pylint: disable=import-outside-toplevel
async def cleanup():
async def cleanup() -> None:
"""Cleanup."""
ent_reg = await entity_registry.async_get_registry(hass)
async_cleanup(hass, dev_reg, ent_reg)

View File

@ -18,7 +18,7 @@ from typing import (
List,
Optional,
Tuple,
cast,
Union,
)
import attr
@ -39,13 +39,11 @@ from homeassistant.util import slugify
from homeassistant.util.yaml import load_yaml
from .singleton import singleton
from .typing import UNDEFINED, HomeAssistantType
from .typing import UNDEFINED, HomeAssistantType, UndefinedType
if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntry # noqa: F401
# mypy: allow-untyped-defs, no-check-untyped-defs
PATH_REGISTRY = "entity_registry.yaml"
DATA_REGISTRY = "entity_registry"
EVENT_ENTITY_REGISTRY_UPDATED = "entity_registry_updated"
@ -222,7 +220,7 @@ class EntityRegistry:
entity_id = self.async_get_entity_id(domain, platform, unique_id)
if entity_id:
return self._async_update_entity( # type: ignore
return self._async_update_entity(
entity_id,
config_entry_id=config_entry_id or UNDEFINED,
device_id=device_id or UNDEFINED,
@ -316,63 +314,56 @@ class EntityRegistry:
for entity in entities:
if entity.disabled_by != DISABLED_DEVICE:
continue
self.async_update_entity( # type: ignore
entity.entity_id, disabled_by=None
)
self.async_update_entity(entity.entity_id, disabled_by=None)
return
entities = async_entries_for_device(self, event.data["device_id"])
for entity in entities:
self.async_update_entity( # type: ignore
entity.entity_id, disabled_by=DISABLED_DEVICE
)
self.async_update_entity(entity.entity_id, disabled_by=DISABLED_DEVICE)
@callback
def async_update_entity(
self,
entity_id,
entity_id: str,
*,
name=UNDEFINED,
icon=UNDEFINED,
area_id=UNDEFINED,
new_entity_id=UNDEFINED,
new_unique_id=UNDEFINED,
disabled_by=UNDEFINED,
):
name: Union[str, None, UndefinedType] = UNDEFINED,
icon: Union[str, None, UndefinedType] = UNDEFINED,
area_id: Union[str, None, UndefinedType] = UNDEFINED,
new_entity_id: Union[str, UndefinedType] = UNDEFINED,
new_unique_id: Union[str, UndefinedType] = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
) -> RegistryEntry:
"""Update properties of an entity."""
return cast( # cast until we have _async_update_entity type hinted
RegistryEntry,
self._async_update_entity(
entity_id,
name=name,
icon=icon,
area_id=area_id,
new_entity_id=new_entity_id,
new_unique_id=new_unique_id,
disabled_by=disabled_by,
),
return self._async_update_entity(
entity_id,
name=name,
icon=icon,
area_id=area_id,
new_entity_id=new_entity_id,
new_unique_id=new_unique_id,
disabled_by=disabled_by,
)
@callback
def _async_update_entity(
self,
entity_id,
entity_id: str,
*,
name=UNDEFINED,
icon=UNDEFINED,
config_entry_id=UNDEFINED,
new_entity_id=UNDEFINED,
device_id=UNDEFINED,
area_id=UNDEFINED,
new_unique_id=UNDEFINED,
disabled_by=UNDEFINED,
capabilities=UNDEFINED,
supported_features=UNDEFINED,
device_class=UNDEFINED,
unit_of_measurement=UNDEFINED,
original_name=UNDEFINED,
original_icon=UNDEFINED,
):
name: Union[str, None, UndefinedType] = UNDEFINED,
icon: Union[str, None, UndefinedType] = UNDEFINED,
config_entry_id: Union[str, None, UndefinedType] = UNDEFINED,
new_entity_id: Union[str, UndefinedType] = UNDEFINED,
device_id: Union[str, None, UndefinedType] = UNDEFINED,
area_id: Union[str, None, UndefinedType] = UNDEFINED,
new_unique_id: Union[str, UndefinedType] = UNDEFINED,
disabled_by: Union[str, None, UndefinedType] = UNDEFINED,
capabilities: Union[Dict[str, Any], None, UndefinedType] = UNDEFINED,
supported_features: Union[int, UndefinedType] = UNDEFINED,
device_class: Union[str, None, UndefinedType] = UNDEFINED,
unit_of_measurement: Union[str, None, UndefinedType] = UNDEFINED,
original_name: Union[str, None, UndefinedType] = UNDEFINED,
original_icon: Union[str, None, UndefinedType] = UNDEFINED,
) -> RegistryEntry:
"""Private facing update properties method."""
old = self.entities[entity_id]
@ -526,7 +517,7 @@ class EntityRegistry:
"""Clear area id from registry entries."""
for entity_id, entry in self.entities.items():
if area_id == entry.area_id:
self._async_update_entity(entity_id, area_id=None) # type: ignore
self._async_update_entity(entity_id, area_id=None)
def _register_entry(self, entry: RegistryEntry) -> None:
self.entities[entry.entity_id] = entry