Add and fix type hints (#36501)
* Fix exceptions.Unauthorized.permission type * Use auth.permission consts more * Auth typing improvements * Helpers typing improvements * Calculate self.state only oncepull/36731/head
parent
49747684a0
commit
0c5ca3084e
|
@ -117,7 +117,8 @@ class TotpAuthModule(MultiFactorAuthModule):
|
|||
|
||||
Mfa module should extend SetupFlow
|
||||
"""
|
||||
user = await self.hass.auth.async_get_user(user_id) # type: ignore
|
||||
user = await self.hass.auth.async_get_user(user_id)
|
||||
assert user is not None
|
||||
return TotpSetupFlow(self, self.input_schema, user)
|
||||
|
||||
async def async_setup_user(self, user_id: str, setup_data: Any) -> str:
|
||||
|
|
|
@ -175,7 +175,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
self._auth_module_id: Optional[str] = None
|
||||
self._auth_manager = auth_provider.hass.auth # type: ignore
|
||||
self._auth_manager = auth_provider.hass.auth
|
||||
self.available_mfa_modules: Dict[str, str] = {}
|
||||
self.created_at = dt_util.utcnow()
|
||||
self.invalid_mfa_times = 0
|
||||
|
@ -224,6 +224,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
|
||||
errors = {}
|
||||
|
||||
assert self._auth_module_id is not None
|
||||
auth_module = self._auth_manager.get_auth_mfa_module(self._auth_module_id)
|
||||
if auth_module is None:
|
||||
# Given an invalid input to async_step_select_mfa_module
|
||||
|
@ -234,7 +235,9 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
auth_module, "async_initialize_login_mfa_step"
|
||||
):
|
||||
try:
|
||||
await auth_module.async_initialize_login_mfa_step(self.user.id)
|
||||
await auth_module.async_initialize_login_mfa_step( # type: ignore
|
||||
self.user.id
|
||||
)
|
||||
except HomeAssistantError:
|
||||
_LOGGER.exception("Error initializing MFA step")
|
||||
return self.async_abort(reason="unknown_error")
|
||||
|
|
|
@ -4,7 +4,7 @@ import voluptuous as vol
|
|||
import voluptuous_serialize
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES
|
||||
from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES, POLICY_EDIT
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.const import HTTP_NOT_FOUND
|
||||
|
@ -180,7 +180,7 @@ class OptionManagerFlowIndexView(FlowManagerIndexView):
|
|||
handler in request is entry_id.
|
||||
"""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit")
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
|
||||
|
||||
# pylint: disable=no-value-for-parameter
|
||||
return await super().post(request)
|
||||
|
@ -195,7 +195,7 @@ class OptionManagerFlowResourceView(FlowManagerResourceView):
|
|||
async def get(self, request, flow_id):
|
||||
"""Get the current state of a data_entry_flow."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit")
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
|
||||
|
||||
return await super().get(request, flow_id)
|
||||
|
||||
|
@ -203,7 +203,7 @@ class OptionManagerFlowResourceView(FlowManagerResourceView):
|
|||
async def post(self, request, flow_id):
|
||||
"""Handle a POST request."""
|
||||
if not request["hass_user"].is_admin:
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit")
|
||||
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
|
||||
|
||||
# pylint: disable=no-value-for-parameter
|
||||
return await super().post(request, flow_id)
|
||||
|
|
|
@ -85,18 +85,12 @@ async def _validate_edit_permission(
|
|||
"""Use for validating user control permissions."""
|
||||
splited = split_entity_id(entity_id)
|
||||
if splited[0] != SWITCH_DOMAIN or not splited[1].startswith(DOMAIN):
|
||||
raise Unauthorized(
|
||||
context=context, entity_id=entity_id, permission=(POLICY_EDIT,)
|
||||
)
|
||||
raise Unauthorized(context=context, entity_id=entity_id, permission=POLICY_EDIT)
|
||||
user = await hass.auth.async_get_user(context.user_id)
|
||||
if user is None:
|
||||
raise UnknownUser(
|
||||
context=context, entity_id=entity_id, permission=(POLICY_EDIT,)
|
||||
)
|
||||
raise UnknownUser(context=context, entity_id=entity_id, permission=POLICY_EDIT)
|
||||
if not user.permissions.check_entity(entity_id, POLICY_EDIT):
|
||||
raise Unauthorized(
|
||||
context=context, entity_id=entity_id, permission=(POLICY_EDIT,)
|
||||
)
|
||||
raise Unauthorized(context=context, entity_id=entity_id, permission=POLICY_EDIT)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistantType, config: Dict) -> bool:
|
||||
|
|
|
@ -79,6 +79,7 @@ from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM, UnitS
|
|||
|
||||
# Typing imports that create a circular dependency
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.auth import AuthManager
|
||||
from homeassistant.config_entries import ConfigEntries
|
||||
from homeassistant.components.http import HomeAssistantHTTP
|
||||
|
||||
|
@ -174,6 +175,7 @@ class CoreState(enum.Enum):
|
|||
class HomeAssistant:
|
||||
"""Root object of the Home Assistant home automation."""
|
||||
|
||||
auth: "AuthManager"
|
||||
http: "HomeAssistantHTTP" = None # type: ignore
|
||||
config_entries: "ConfigEntries" = None # type: ignore
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""The exceptions used by Home Assistant."""
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import jinja2
|
||||
|
||||
|
@ -49,7 +49,7 @@ class Unauthorized(HomeAssistantError):
|
|||
entity_id: Optional[str] = None,
|
||||
config_entry_id: Optional[str] = None,
|
||||
perm_category: Optional[str] = None,
|
||||
permission: Optional[Tuple[str]] = None,
|
||||
permission: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Unauthorized error."""
|
||||
super().__init__(self.__class__.__name__)
|
||||
|
|
|
@ -5,7 +5,7 @@ from datetime import datetime, timedelta
|
|||
import functools as ft
|
||||
import logging
|
||||
from timeit import default_timer as timer
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, Awaitable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from homeassistant.config import DATA_CUSTOMIZE
|
||||
from homeassistant.const import (
|
||||
|
@ -32,11 +32,10 @@ from homeassistant.helpers.entity_registry import (
|
|||
EVENT_ENTITY_REGISTRY_UPDATED,
|
||||
RegistryEntry,
|
||||
)
|
||||
from homeassistant.helpers.event import Event
|
||||
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
SLOW_UPDATE_WARNING = 10
|
||||
|
||||
|
@ -258,7 +257,7 @@ class Entity(ABC):
|
|||
self._context = context
|
||||
self._context_set = dt_util.utcnow()
|
||||
|
||||
async def async_update_ha_state(self, force_refresh=False):
|
||||
async def async_update_ha_state(self, force_refresh: bool = False) -> None:
|
||||
"""Update Home Assistant with current state of entity.
|
||||
|
||||
If force_refresh == True will update entity before setting state.
|
||||
|
@ -294,14 +293,15 @@ class Entity(ABC):
|
|||
f"No entity id specified for entity {self.name}"
|
||||
)
|
||||
|
||||
self._async_write_ha_state() # type: ignore
|
||||
self._async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def _async_write_ha_state(self):
|
||||
def _async_write_ha_state(self) -> None:
|
||||
"""Write the state to the state machine."""
|
||||
if self.registry_entry and self.registry_entry.disabled_by:
|
||||
if not self._disabled_reported:
|
||||
self._disabled_reported = True
|
||||
assert self.platform is not None
|
||||
_LOGGER.warning(
|
||||
"Entity %s is incorrectly being triggered for updates while it is disabled. This is a bug in the %s integration.",
|
||||
self.entity_id,
|
||||
|
@ -317,9 +317,8 @@ class Entity(ABC):
|
|||
if not self.available:
|
||||
state = STATE_UNAVAILABLE
|
||||
else:
|
||||
state = self.state
|
||||
|
||||
state = STATE_UNKNOWN if state is None else str(state)
|
||||
sstate = self.state
|
||||
state = STATE_UNKNOWN if sstate is None else str(sstate)
|
||||
attr.update(self.state_attributes or {})
|
||||
attr.update(self.device_state_attributes or {})
|
||||
|
||||
|
@ -383,6 +382,7 @@ class Entity(ABC):
|
|||
)
|
||||
|
||||
# Overwrite properties that have been set in the config file.
|
||||
assert self.hass is not None
|
||||
if DATA_CUSTOMIZE in self.hass.data:
|
||||
attr.update(self.hass.data[DATA_CUSTOMIZE].get(self.entity_id))
|
||||
|
||||
|
@ -403,7 +403,7 @@ class Entity(ABC):
|
|||
pass
|
||||
|
||||
if (
|
||||
self._context is not None
|
||||
self._context_set is not None
|
||||
and dt_util.utcnow() - self._context_set > self.context_recent_time
|
||||
):
|
||||
self._context = None
|
||||
|
@ -413,7 +413,7 @@ class Entity(ABC):
|
|||
self.entity_id, state, attr, self.force_update, self._context
|
||||
)
|
||||
|
||||
def schedule_update_ha_state(self, force_refresh=False):
|
||||
def schedule_update_ha_state(self, force_refresh: bool = False) -> None:
|
||||
"""Schedule an update ha state change task.
|
||||
|
||||
Scheduling the update avoids executor deadlocks.
|
||||
|
@ -423,10 +423,11 @@ class Entity(ABC):
|
|||
If state is changed more than once before the ha state change task has
|
||||
been executed, the intermediate state transitions will be missed.
|
||||
"""
|
||||
self.hass.add_job(self.async_update_ha_state(force_refresh))
|
||||
assert self.hass is not None
|
||||
self.hass.add_job(self.async_update_ha_state(force_refresh)) # type: ignore
|
||||
|
||||
@callback
|
||||
def async_schedule_update_ha_state(self, force_refresh=False):
|
||||
def async_schedule_update_ha_state(self, force_refresh: bool = False) -> None:
|
||||
"""Schedule an update ha state change task.
|
||||
|
||||
This method must be run in the event loop.
|
||||
|
@ -438,11 +439,12 @@ class Entity(ABC):
|
|||
been executed, the intermediate state transitions will be missed.
|
||||
"""
|
||||
if force_refresh:
|
||||
assert self.hass is not None
|
||||
self.hass.async_create_task(self.async_update_ha_state(force_refresh))
|
||||
else:
|
||||
self.async_write_ha_state()
|
||||
|
||||
async def async_device_update(self, warning=True):
|
||||
async def async_device_update(self, warning: bool = True) -> None:
|
||||
"""Process 'update' or 'async_update' from entity.
|
||||
|
||||
This method is a coroutine.
|
||||
|
@ -455,6 +457,7 @@ class Entity(ABC):
|
|||
if self.parallel_updates:
|
||||
await self.parallel_updates.acquire()
|
||||
|
||||
assert self.hass is not None
|
||||
if warning:
|
||||
update_warn = self.hass.loop.call_later(
|
||||
SLOW_UPDATE_WARNING,
|
||||
|
@ -467,9 +470,11 @@ class Entity(ABC):
|
|||
try:
|
||||
# pylint: disable=no-member
|
||||
if hasattr(self, "async_update"):
|
||||
await self.async_update()
|
||||
await self.async_update() # type: ignore
|
||||
elif hasattr(self, "update"):
|
||||
await self.hass.async_add_executor_job(self.update)
|
||||
await self.hass.async_add_executor_job(
|
||||
self.update # type: ignore
|
||||
)
|
||||
finally:
|
||||
self._update_staged = False
|
||||
if warning:
|
||||
|
@ -534,7 +539,7 @@ class Entity(ABC):
|
|||
Not to be extended by integrations.
|
||||
"""
|
||||
|
||||
async def _async_registry_updated(self, event):
|
||||
async def _async_registry_updated(self, event: Event) -> None:
|
||||
"""Handle entity registry update."""
|
||||
data = event.data
|
||||
if data["action"] == "remove" and data["entity_id"] == self.entity_id:
|
||||
|
@ -547,24 +552,28 @@ class Entity(ABC):
|
|||
):
|
||||
return
|
||||
|
||||
assert self.hass is not None
|
||||
ent_reg = await self.hass.helpers.entity_registry.async_get_registry()
|
||||
old = self.registry_entry
|
||||
self.registry_entry = ent_reg.async_get(data["entity_id"])
|
||||
assert self.registry_entry is not None
|
||||
|
||||
if self.registry_entry.disabled_by is not None:
|
||||
await self.async_remove()
|
||||
return
|
||||
|
||||
assert old is not None
|
||||
if self.registry_entry.entity_id == old.entity_id:
|
||||
self.async_write_ha_state()
|
||||
return
|
||||
|
||||
await self.async_remove()
|
||||
|
||||
assert self.platform is not None
|
||||
self.entity_id = self.registry_entry.entity_id
|
||||
await self.platform.async_add_entities([self])
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Return the comparison."""
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
|
@ -587,8 +596,7 @@ class Entity(ABC):
|
|||
"""Return the representation."""
|
||||
return f"<Entity {self.name}: {self.state}>"
|
||||
|
||||
# call an requests
|
||||
async def async_request_call(self, coro):
|
||||
async def async_request_call(self, coro: Awaitable) -> None:
|
||||
"""Process request batched."""
|
||||
if self.parallel_updates:
|
||||
await self.parallel_updates.acquire()
|
||||
|
@ -617,16 +625,18 @@ class ToggleEntity(Entity):
|
|||
"""Turn the entity on."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def async_turn_on(self, **kwargs):
|
||||
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||
"""Turn the entity on."""
|
||||
assert self.hass is not None
|
||||
await self.hass.async_add_executor_job(ft.partial(self.turn_on, **kwargs))
|
||||
|
||||
def turn_off(self, **kwargs: Any) -> None:
|
||||
"""Turn the entity off."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def async_turn_off(self, **kwargs):
|
||||
async def async_turn_off(self, **kwargs: Any) -> None:
|
||||
"""Turn the entity off."""
|
||||
assert self.hass is not None
|
||||
await self.hass.async_add_executor_job(ft.partial(self.turn_off, **kwargs))
|
||||
|
||||
def toggle(self, **kwargs: Any) -> None:
|
||||
|
@ -636,7 +646,7 @@ class ToggleEntity(Entity):
|
|||
else:
|
||||
self.turn_on(**kwargs)
|
||||
|
||||
async def async_toggle(self, **kwargs):
|
||||
async def async_toggle(self, **kwargs: Any) -> None:
|
||||
"""Toggle the entity."""
|
||||
if self.is_on:
|
||||
await self.async_turn_off(**kwargs)
|
||||
|
|
|
@ -542,7 +542,7 @@ class EntityPlatform:
|
|||
for entity in self.entities.values():
|
||||
if not entity.should_poll:
|
||||
continue
|
||||
tasks.append(entity.async_update_ha_state(True)) # type: ignore
|
||||
tasks.append(entity.async_update_ha_state(True))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
|
|
@ -505,7 +505,7 @@ def async_register_admin_service(
|
|||
"""Register a service that requires admin access."""
|
||||
|
||||
@wraps(service_func)
|
||||
async def admin_handler(call):
|
||||
async def admin_handler(call: ha.ServiceCall) -> None:
|
||||
if call.context.user_id:
|
||||
user = await hass.auth.async_get_user(call.context.user_id)
|
||||
if user is None:
|
||||
|
|
Loading…
Reference in New Issue