From b4afb1cb6b496867d39dd679acaee07f571f5b9b Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Thu, 15 Sep 2022 11:53:00 +0200 Subject: [PATCH] Make use of generic EntityComponent (#78492) --- .../components/automation/__init__.py | 4 ++-- homeassistant/components/camera/__init__.py | 19 ++++++------------- .../components/camera/media_source.py | 9 +++------ homeassistant/components/group/__init__.py | 18 +++++++----------- homeassistant/components/person/__init__.py | 11 +++++------ homeassistant/components/remote/__init__.py | 10 ++++++---- homeassistant/components/script/__init__.py | 6 +++--- 7 files changed, 32 insertions(+), 45 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index b4fc82f52b9..841f015fa0f 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -221,7 +221,7 @@ def automations_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list if DOMAIN not in hass.data: return [] - component = hass.data[DOMAIN] + component: EntityComponent[AutomationEntity] = hass.data[DOMAIN] return [ automation_entity.entity_id @@ -661,7 +661,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity): async def _async_process_config( hass: HomeAssistant, config: dict[str, Any], - component: EntityComponent, + component: EntityComponent[AutomationEntity], ) -> bool: """Process config and add automations. diff --git a/homeassistant/components/camera/__init__.py b/homeassistant/components/camera/__init__.py index 6e2b36070ae..fa807dd1440 100644 --- a/homeassistant/components/camera/__init__.py +++ b/homeassistant/components/camera/__init__.py @@ -322,12 +322,9 @@ def async_register_rtsp_to_web_rtc_provider( async def _async_refresh_providers(hass: HomeAssistant) -> None: """Check all cameras for any state changes for registered providers.""" - component: EntityComponent = hass.data[DOMAIN] + component: EntityComponent[Camera] = hass.data[DOMAIN] await asyncio.gather( - *( - cast(Camera, camera).async_refresh_providers() - for camera in component.entities - ) + *(camera.async_refresh_providers() for camera in component.entities) ) @@ -343,7 +340,7 @@ def _async_get_rtsp_to_web_rtc_providers( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the camera component.""" - component = hass.data[DOMAIN] = EntityComponent( + component = hass.data[DOMAIN] = EntityComponent[Camera]( _LOGGER, DOMAIN, hass, SCAN_INTERVAL ) @@ -363,7 +360,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def preload_stream(_event: Event) -> None: for camera in component.entities: - camera = cast(Camera, camera) camera_prefs = prefs.get(camera.entity_id) if not camera_prefs.preload_stream: continue @@ -380,7 +376,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: def update_tokens(time: datetime) -> None: """Update tokens of the entities.""" for entity in component.entities: - entity = cast(Camera, entity) entity.async_update_token() entity.async_write_ha_state() @@ -411,13 +406,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up a config entry.""" - component: EntityComponent = hass.data[DOMAIN] + component: EntityComponent[Camera] = hass.data[DOMAIN] return await component.async_setup_entry(entry) async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" - component: EntityComponent = hass.data[DOMAIN] + component: EntityComponent[Camera] = hass.data[DOMAIN] return await component.async_unload_entry(entry) @@ -698,7 +693,7 @@ class CameraView(HomeAssistantView): requires_auth = False - def __init__(self, component: EntityComponent) -> None: + def __init__(self, component: EntityComponent[Camera]) -> None: """Initialize a basic camera view.""" self.component = component @@ -707,8 +702,6 @@ class CameraView(HomeAssistantView): if (camera := self.component.get_entity(entity_id)) is None: raise web.HTTPNotFound() - camera = cast(Camera, camera) - authenticated = ( request[KEY_AUTHENTICATED] or request.query.get("token") in camera.access_tokens diff --git a/homeassistant/components/camera/media_source.py b/homeassistant/components/camera/media_source.py index 117f65edb07..e386e864ded 100644 --- a/homeassistant/components/camera/media_source.py +++ b/homeassistant/components/camera/media_source.py @@ -1,8 +1,6 @@ """Expose cameras as media sources.""" from __future__ import annotations -from typing import Optional, cast - from homeassistant.components.media_player import BrowseError, MediaClass from homeassistant.components.media_source.error import Unresolvable from homeassistant.components.media_source.models import ( @@ -37,8 +35,8 @@ class CameraMediaSource(MediaSource): async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: """Resolve media to a url.""" - component: EntityComponent = self.hass.data[DOMAIN] - camera = cast(Optional[Camera], component.get_entity(item.identifier)) + component: EntityComponent[Camera] = self.hass.data[DOMAIN] + camera = component.get_entity(item.identifier) if not camera: raise Unresolvable(f"Could not resolve media item: {item.identifier}") @@ -72,11 +70,10 @@ class CameraMediaSource(MediaSource): can_stream_hls = "stream" in self.hass.config.components # Root. List cameras. - component: EntityComponent = self.hass.data[DOMAIN] + component: EntityComponent[Camera] = self.hass.data[DOMAIN] children = [] not_shown = 0 for camera in component.entities: - camera = cast(Camera, camera) stream_type = camera.frontend_stream_type if stream_type is None: diff --git a/homeassistant/components/group/__init__.py b/homeassistant/components/group/__init__.py index c1759432ade..bdf295e35ae 100644 --- a/homeassistant/components/group/__init__.py +++ b/homeassistant/components/group/__init__.py @@ -6,7 +6,7 @@ import asyncio from collections.abc import Collection, Iterable from contextvars import ContextVar import logging -from typing import Any, Protocol, Union, cast +from typing import Any, Protocol, cast import voluptuous as vol @@ -119,9 +119,9 @@ CONFIG_SCHEMA = vol.Schema( ) -def _async_get_component(hass: HomeAssistant) -> EntityComponent: +def _async_get_component(hass: HomeAssistant) -> EntityComponent[Group]: if (component := hass.data.get(DOMAIN)) is None: - component = hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass) + component = hass.data[DOMAIN] = EntityComponent[Group](_LOGGER, DOMAIN, hass) return component @@ -288,11 +288,11 @@ async def async_remove_entry(hass: HomeAssistant, entry: ConfigEntry) -> None: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up all groups found defined in the configuration.""" if DOMAIN not in hass.data: - hass.data[DOMAIN] = EntityComponent(_LOGGER, DOMAIN, hass) + hass.data[DOMAIN] = EntityComponent[Group](_LOGGER, DOMAIN, hass) await async_process_integration_platform_for_component(hass, DOMAIN) - component: EntityComponent = hass.data[DOMAIN] + component: EntityComponent[Group] = hass.data[DOMAIN] hass.data[REG_KEY] = GroupIntegrationRegistry() @@ -302,11 +302,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def reload_service_handler(service: ServiceCall) -> None: """Remove all user-defined groups and load new ones from config.""" - auto = [ - cast(Group, e) - for e in component.entities - if not cast(Group, e).user_defined - ] + auto = [e for e in component.entities if not e.user_defined] if (conf := await component.async_prepare_reload()) is None: return @@ -331,7 +327,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Handle dynamic group service functions.""" object_id = service.data[ATTR_OBJECT_ID] entity_id = f"{DOMAIN}.{object_id}" - group: Group | None = cast(Union[Group, None], component.get_entity(entity_id)) + group = component.get_entity(entity_id) # new group if service.service == SERVICE_SET and group is None: diff --git a/homeassistant/components/person/__init__.py b/homeassistant/components/person/__init__.py index 0823e9e4b55..86a132027d8 100644 --- a/homeassistant/components/person/__init__.py +++ b/homeassistant/components/person/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations import logging -from typing import cast import voluptuous as vol @@ -107,7 +106,7 @@ async def async_add_user_device_tracker( hass: HomeAssistant, user_id: str, device_tracker_entity_id: str ): """Add a device tracker to a person linked to a user.""" - coll = cast(PersonStorageCollection, hass.data[DOMAIN][1]) + coll: PersonStorageCollection = hass.data[DOMAIN][1] for person in coll.async_items(): if person.get(ATTR_USER_ID) != user_id: @@ -134,12 +133,12 @@ def persons_with_entity(hass: HomeAssistant, entity_id: str) -> list[str]: ): return [] - component: EntityComponent = hass.data[DOMAIN][2] + component: EntityComponent[Person] = hass.data[DOMAIN][2] return [ person_entity.entity_id for person_entity in component.entities - if entity_id in cast(Person, person_entity).device_trackers + if entity_id in person_entity.device_trackers ] @@ -149,12 +148,12 @@ def entities_in_person(hass: HomeAssistant, entity_id: str) -> list[str]: if DOMAIN not in hass.data: return [] - component: EntityComponent = hass.data[DOMAIN][2] + component: EntityComponent[Person] = hass.data[DOMAIN][2] if (person_entity := component.get_entity(entity_id)) is None: return [] - return cast(Person, person_entity).device_trackers + return person_entity.device_trackers CREATE_FIELDS = { diff --git a/homeassistant/components/remote/__init__.py b/homeassistant/components/remote/__init__.py index b1b856cfa29..6ba5ca89d2d 100644 --- a/homeassistant/components/remote/__init__.py +++ b/homeassistant/components/remote/__init__.py @@ -7,7 +7,7 @@ from datetime import timedelta from enum import IntEnum import functools as ft import logging -from typing import Any, cast, final +from typing import Any, final import voluptuous as vol @@ -88,7 +88,7 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Track states and offer events for remotes.""" - component = hass.data[DOMAIN] = EntityComponent( + component = hass.data[DOMAIN] = EntityComponent[RemoteEntity]( _LOGGER, DOMAIN, hass, SCAN_INTERVAL ) await component.async_setup(config) @@ -145,12 +145,14 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up a config entry.""" - return await cast(EntityComponent, hass.data[DOMAIN]).async_setup_entry(entry) + component: EntityComponent[RemoteEntity] = hass.data[DOMAIN] + return await component.async_setup_entry(entry) async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" - return await cast(EntityComponent, hass.data[DOMAIN]).async_unload_entry(entry) + component: EntityComponent[RemoteEntity] = hass.data[DOMAIN] + return await component.async_unload_entry(entry) @dataclass diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index 9791f0e588e..dce99620e44 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -86,7 +86,7 @@ def _scripts_with_x( if DOMAIN not in hass.data: return [] - component = hass.data[DOMAIN] + component: EntityComponent[ScriptEntity] = hass.data[DOMAIN] return [ script_entity.entity_id @@ -100,7 +100,7 @@ def _x_in_script(hass: HomeAssistant, entity_id: str, property_name: str) -> lis if DOMAIN not in hass.data: return [] - component = hass.data[DOMAIN] + component: EntityComponent[ScriptEntity] = hass.data[DOMAIN] if (script_entity := component.get_entity(entity_id)) is None: return [] @@ -150,7 +150,7 @@ def scripts_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str if DOMAIN not in hass.data: return [] - component = hass.data[DOMAIN] + component: EntityComponent[ScriptEntity] = hass.data[DOMAIN] return [ script_entity.entity_id