From 0c5ca3084ee521b09180461b78dfc50efa13812a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Skytt=C3=A4?= Date: Sat, 6 Jun 2020 21:34:56 +0300 Subject: [PATCH] Add and fix type hints (#36501) * Fix exceptions.Unauthorized.permission type * Use auth.permission consts more * Auth typing improvements * Helpers typing improvements * Calculate self.state only once --- homeassistant/auth/mfa_modules/totp.py | 3 +- homeassistant/auth/providers/__init__.py | 7 ++- .../components/config/config_entries.py | 8 +-- .../components/switcher_kis/__init__.py | 12 +--- homeassistant/core.py | 2 + homeassistant/exceptions.py | 4 +- homeassistant/helpers/entity.py | 56 +++++++++++-------- homeassistant/helpers/entity_platform.py | 2 +- homeassistant/helpers/service.py | 2 +- 9 files changed, 53 insertions(+), 43 deletions(-) diff --git a/homeassistant/auth/mfa_modules/totp.py b/homeassistant/auth/mfa_modules/totp.py index d35f237f424..2fc8c379861 100644 --- a/homeassistant/auth/mfa_modules/totp.py +++ b/homeassistant/auth/mfa_modules/totp.py @@ -117,7 +117,8 @@ class TotpAuthModule(MultiFactorAuthModule): Mfa module should extend SetupFlow """ - user = await self.hass.auth.async_get_user(user_id) # type: ignore + user = await self.hass.auth.async_get_user(user_id) + assert user is not None return TotpSetupFlow(self, self.input_schema, user) async def async_setup_user(self, user_id: str, setup_data: Any) -> str: diff --git a/homeassistant/auth/providers/__init__.py b/homeassistant/auth/providers/__init__.py index 1fa70e42b3f..35208bd847c 100644 --- a/homeassistant/auth/providers/__init__.py +++ b/homeassistant/auth/providers/__init__.py @@ -175,7 +175,7 @@ class LoginFlow(data_entry_flow.FlowHandler): """Initialize the login flow.""" self._auth_provider = auth_provider self._auth_module_id: Optional[str] = None - self._auth_manager = auth_provider.hass.auth # type: ignore + self._auth_manager = auth_provider.hass.auth self.available_mfa_modules: Dict[str, str] = {} self.created_at = dt_util.utcnow() self.invalid_mfa_times = 0 @@ -224,6 +224,7 @@ class LoginFlow(data_entry_flow.FlowHandler): errors = {} + assert self._auth_module_id is not None auth_module = self._auth_manager.get_auth_mfa_module(self._auth_module_id) if auth_module is None: # Given an invalid input to async_step_select_mfa_module @@ -234,7 +235,9 @@ class LoginFlow(data_entry_flow.FlowHandler): auth_module, "async_initialize_login_mfa_step" ): try: - await auth_module.async_initialize_login_mfa_step(self.user.id) + await auth_module.async_initialize_login_mfa_step( # type: ignore + self.user.id + ) except HomeAssistantError: _LOGGER.exception("Error initializing MFA step") return self.async_abort(reason="unknown_error") diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index 584255764a3..10ef2aeecb0 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -4,7 +4,7 @@ import voluptuous as vol import voluptuous_serialize from homeassistant import config_entries, data_entry_flow -from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES +from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES, POLICY_EDIT from homeassistant.components import websocket_api from homeassistant.components.http import HomeAssistantView from homeassistant.const import HTTP_NOT_FOUND @@ -180,7 +180,7 @@ class OptionManagerFlowIndexView(FlowManagerIndexView): handler in request is entry_id. """ if not request["hass_user"].is_admin: - raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit") + raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) # pylint: disable=no-value-for-parameter return await super().post(request) @@ -195,7 +195,7 @@ class OptionManagerFlowResourceView(FlowManagerResourceView): async def get(self, request, flow_id): """Get the current state of a data_entry_flow.""" if not request["hass_user"].is_admin: - raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit") + raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) return await super().get(request, flow_id) @@ -203,7 +203,7 @@ class OptionManagerFlowResourceView(FlowManagerResourceView): async def post(self, request, flow_id): """Handle a POST request.""" if not request["hass_user"].is_admin: - raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit") + raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT) # pylint: disable=no-value-for-parameter return await super().post(request, flow_id) diff --git a/homeassistant/components/switcher_kis/__init__.py b/homeassistant/components/switcher_kis/__init__.py index 8369fdd8975..4a9a564ec3b 100644 --- a/homeassistant/components/switcher_kis/__init__.py +++ b/homeassistant/components/switcher_kis/__init__.py @@ -85,18 +85,12 @@ async def _validate_edit_permission( """Use for validating user control permissions.""" splited = split_entity_id(entity_id) if splited[0] != SWITCH_DOMAIN or not splited[1].startswith(DOMAIN): - raise Unauthorized( - context=context, entity_id=entity_id, permission=(POLICY_EDIT,) - ) + raise Unauthorized(context=context, entity_id=entity_id, permission=POLICY_EDIT) user = await hass.auth.async_get_user(context.user_id) if user is None: - raise UnknownUser( - context=context, entity_id=entity_id, permission=(POLICY_EDIT,) - ) + raise UnknownUser(context=context, entity_id=entity_id, permission=POLICY_EDIT) if not user.permissions.check_entity(entity_id, POLICY_EDIT): - raise Unauthorized( - context=context, entity_id=entity_id, permission=(POLICY_EDIT,) - ) + raise Unauthorized(context=context, entity_id=entity_id, permission=POLICY_EDIT) async def async_setup(hass: HomeAssistantType, config: Dict) -> bool: diff --git a/homeassistant/core.py b/homeassistant/core.py index eb7457daecb..1df05150b14 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -79,6 +79,7 @@ from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM, UnitS # Typing imports that create a circular dependency if TYPE_CHECKING: + from homeassistant.auth import AuthManager from homeassistant.config_entries import ConfigEntries from homeassistant.components.http import HomeAssistantHTTP @@ -174,6 +175,7 @@ class CoreState(enum.Enum): class HomeAssistant: """Root object of the Home Assistant home automation.""" + auth: "AuthManager" http: "HomeAssistantHTTP" = None # type: ignore config_entries: "ConfigEntries" = None # type: ignore diff --git a/homeassistant/exceptions.py b/homeassistant/exceptions.py index 745d80d386b..d085c1a9021 100644 --- a/homeassistant/exceptions.py +++ b/homeassistant/exceptions.py @@ -1,5 +1,5 @@ """The exceptions used by Home Assistant.""" -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional import jinja2 @@ -49,7 +49,7 @@ class Unauthorized(HomeAssistantError): entity_id: Optional[str] = None, config_entry_id: Optional[str] = None, perm_category: Optional[str] = None, - permission: Optional[Tuple[str]] = None, + permission: Optional[str] = None, ) -> None: """Unauthorized error.""" super().__init__(self.__class__.__name__) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index b5d36f6a2f5..29bf1a180a9 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta import functools as ft import logging from timeit import default_timer as timer -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Awaitable, Dict, Iterable, List, Optional, Union from homeassistant.config import DATA_CUSTOMIZE from homeassistant.const import ( @@ -32,11 +32,10 @@ from homeassistant.helpers.entity_registry import ( EVENT_ENTITY_REGISTRY_UPDATED, RegistryEntry, ) +from homeassistant.helpers.event import Event from homeassistant.util import dt as dt_util, ensure_unique_string, slugify from homeassistant.util.async_ import run_callback_threadsafe -# mypy: allow-untyped-defs, no-check-untyped-defs - _LOGGER = logging.getLogger(__name__) SLOW_UPDATE_WARNING = 10 @@ -258,7 +257,7 @@ class Entity(ABC): self._context = context self._context_set = dt_util.utcnow() - async def async_update_ha_state(self, force_refresh=False): + async def async_update_ha_state(self, force_refresh: bool = False) -> None: """Update Home Assistant with current state of entity. If force_refresh == True will update entity before setting state. @@ -294,14 +293,15 @@ class Entity(ABC): f"No entity id specified for entity {self.name}" ) - self._async_write_ha_state() # type: ignore + self._async_write_ha_state() @callback - def _async_write_ha_state(self): + def _async_write_ha_state(self) -> None: """Write the state to the state machine.""" if self.registry_entry and self.registry_entry.disabled_by: if not self._disabled_reported: self._disabled_reported = True + assert self.platform is not None _LOGGER.warning( "Entity %s is incorrectly being triggered for updates while it is disabled. This is a bug in the %s integration.", self.entity_id, @@ -317,9 +317,8 @@ class Entity(ABC): if not self.available: state = STATE_UNAVAILABLE else: - state = self.state - - state = STATE_UNKNOWN if state is None else str(state) + sstate = self.state + state = STATE_UNKNOWN if sstate is None else str(sstate) attr.update(self.state_attributes or {}) attr.update(self.device_state_attributes or {}) @@ -383,6 +382,7 @@ class Entity(ABC): ) # Overwrite properties that have been set in the config file. + assert self.hass is not None if DATA_CUSTOMIZE in self.hass.data: attr.update(self.hass.data[DATA_CUSTOMIZE].get(self.entity_id)) @@ -403,7 +403,7 @@ class Entity(ABC): pass if ( - self._context is not None + self._context_set is not None and dt_util.utcnow() - self._context_set > self.context_recent_time ): self._context = None @@ -413,7 +413,7 @@ class Entity(ABC): self.entity_id, state, attr, self.force_update, self._context ) - def schedule_update_ha_state(self, force_refresh=False): + def schedule_update_ha_state(self, force_refresh: bool = False) -> None: """Schedule an update ha state change task. Scheduling the update avoids executor deadlocks. @@ -423,10 +423,11 @@ class Entity(ABC): If state is changed more than once before the ha state change task has been executed, the intermediate state transitions will be missed. """ - self.hass.add_job(self.async_update_ha_state(force_refresh)) + assert self.hass is not None + self.hass.add_job(self.async_update_ha_state(force_refresh)) # type: ignore @callback - def async_schedule_update_ha_state(self, force_refresh=False): + def async_schedule_update_ha_state(self, force_refresh: bool = False) -> None: """Schedule an update ha state change task. This method must be run in the event loop. @@ -438,11 +439,12 @@ class Entity(ABC): been executed, the intermediate state transitions will be missed. """ if force_refresh: + assert self.hass is not None self.hass.async_create_task(self.async_update_ha_state(force_refresh)) else: self.async_write_ha_state() - async def async_device_update(self, warning=True): + async def async_device_update(self, warning: bool = True) -> None: """Process 'update' or 'async_update' from entity. This method is a coroutine. @@ -455,6 +457,7 @@ class Entity(ABC): if self.parallel_updates: await self.parallel_updates.acquire() + assert self.hass is not None if warning: update_warn = self.hass.loop.call_later( SLOW_UPDATE_WARNING, @@ -467,9 +470,11 @@ class Entity(ABC): try: # pylint: disable=no-member if hasattr(self, "async_update"): - await self.async_update() + await self.async_update() # type: ignore elif hasattr(self, "update"): - await self.hass.async_add_executor_job(self.update) + await self.hass.async_add_executor_job( + self.update # type: ignore + ) finally: self._update_staged = False if warning: @@ -534,7 +539,7 @@ class Entity(ABC): Not to be extended by integrations. """ - async def _async_registry_updated(self, event): + async def _async_registry_updated(self, event: Event) -> None: """Handle entity registry update.""" data = event.data if data["action"] == "remove" and data["entity_id"] == self.entity_id: @@ -547,24 +552,28 @@ class Entity(ABC): ): return + assert self.hass is not None ent_reg = await self.hass.helpers.entity_registry.async_get_registry() old = self.registry_entry self.registry_entry = ent_reg.async_get(data["entity_id"]) + assert self.registry_entry is not None if self.registry_entry.disabled_by is not None: await self.async_remove() return + assert old is not None if self.registry_entry.entity_id == old.entity_id: self.async_write_ha_state() return await self.async_remove() + assert self.platform is not None self.entity_id = self.registry_entry.entity_id await self.platform.async_add_entities([self]) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Return the comparison.""" if not isinstance(other, self.__class__): return False @@ -587,8 +596,7 @@ class Entity(ABC): """Return the representation.""" return f"" - # call an requests - async def async_request_call(self, coro): + async def async_request_call(self, coro: Awaitable) -> None: """Process request batched.""" if self.parallel_updates: await self.parallel_updates.acquire() @@ -617,16 +625,18 @@ class ToggleEntity(Entity): """Turn the entity on.""" raise NotImplementedError() - async def async_turn_on(self, **kwargs): + async def async_turn_on(self, **kwargs: Any) -> None: """Turn the entity on.""" + assert self.hass is not None await self.hass.async_add_executor_job(ft.partial(self.turn_on, **kwargs)) def turn_off(self, **kwargs: Any) -> None: """Turn the entity off.""" raise NotImplementedError() - async def async_turn_off(self, **kwargs): + async def async_turn_off(self, **kwargs: Any) -> None: """Turn the entity off.""" + assert self.hass is not None await self.hass.async_add_executor_job(ft.partial(self.turn_off, **kwargs)) def toggle(self, **kwargs: Any) -> None: @@ -636,7 +646,7 @@ class ToggleEntity(Entity): else: self.turn_on(**kwargs) - async def async_toggle(self, **kwargs): + async def async_toggle(self, **kwargs: Any) -> None: """Toggle the entity.""" if self.is_on: await self.async_turn_off(**kwargs) diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 5eb5b213732..28fd83d99c1 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -542,7 +542,7 @@ class EntityPlatform: for entity in self.entities.values(): if not entity.should_poll: continue - tasks.append(entity.async_update_ha_state(True)) # type: ignore + tasks.append(entity.async_update_ha_state(True)) if tasks: await asyncio.gather(*tasks) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index af4bdb50fa4..2c4f02990bf 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -505,7 +505,7 @@ def async_register_admin_service( """Register a service that requires admin access.""" @wraps(service_func) - async def admin_handler(call): + async def admin_handler(call: ha.ServiceCall) -> None: if call.context.user_id: user = await hass.auth.async_get_user(call.context.user_id) if user is None: