Add and fix type hints (#36501)

* Fix exceptions.Unauthorized.permission type

* Use auth.permission consts more

* Auth typing improvements

* Helpers typing improvements

* Calculate self.state only once
pull/36731/head
Ville Skyttä 2020-06-06 21:34:56 +03:00 committed by GitHub
parent 49747684a0
commit 0c5ca3084e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 53 additions and 43 deletions

View File

@ -117,7 +117,8 @@ class TotpAuthModule(MultiFactorAuthModule):
Mfa module should extend SetupFlow
"""
user = await self.hass.auth.async_get_user(user_id) # type: ignore
user = await self.hass.auth.async_get_user(user_id)
assert user is not None
return TotpSetupFlow(self, self.input_schema, user)
async def async_setup_user(self, user_id: str, setup_data: Any) -> str:

View File

@ -175,7 +175,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
"""Initialize the login flow."""
self._auth_provider = auth_provider
self._auth_module_id: Optional[str] = None
self._auth_manager = auth_provider.hass.auth # type: ignore
self._auth_manager = auth_provider.hass.auth
self.available_mfa_modules: Dict[str, str] = {}
self.created_at = dt_util.utcnow()
self.invalid_mfa_times = 0
@ -224,6 +224,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
errors = {}
assert self._auth_module_id is not None
auth_module = self._auth_manager.get_auth_mfa_module(self._auth_module_id)
if auth_module is None:
# Given an invalid input to async_step_select_mfa_module
@ -234,7 +235,9 @@ class LoginFlow(data_entry_flow.FlowHandler):
auth_module, "async_initialize_login_mfa_step"
):
try:
await auth_module.async_initialize_login_mfa_step(self.user.id)
await auth_module.async_initialize_login_mfa_step( # type: ignore
self.user.id
)
except HomeAssistantError:
_LOGGER.exception("Error initializing MFA step")
return self.async_abort(reason="unknown_error")

View File

@ -4,7 +4,7 @@ import voluptuous as vol
import voluptuous_serialize
from homeassistant import config_entries, data_entry_flow
from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES
from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES, POLICY_EDIT
from homeassistant.components import websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.const import HTTP_NOT_FOUND
@ -180,7 +180,7 @@ class OptionManagerFlowIndexView(FlowManagerIndexView):
handler in request is entry_id.
"""
if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit")
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
# pylint: disable=no-value-for-parameter
return await super().post(request)
@ -195,7 +195,7 @@ class OptionManagerFlowResourceView(FlowManagerResourceView):
async def get(self, request, flow_id):
"""Get the current state of a data_entry_flow."""
if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit")
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
return await super().get(request, flow_id)
@ -203,7 +203,7 @@ class OptionManagerFlowResourceView(FlowManagerResourceView):
async def post(self, request, flow_id):
"""Handle a POST request."""
if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit")
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
# pylint: disable=no-value-for-parameter
return await super().post(request, flow_id)

View File

@ -85,18 +85,12 @@ async def _validate_edit_permission(
"""Use for validating user control permissions."""
splited = split_entity_id(entity_id)
if splited[0] != SWITCH_DOMAIN or not splited[1].startswith(DOMAIN):
raise Unauthorized(
context=context, entity_id=entity_id, permission=(POLICY_EDIT,)
)
raise Unauthorized(context=context, entity_id=entity_id, permission=POLICY_EDIT)
user = await hass.auth.async_get_user(context.user_id)
if user is None:
raise UnknownUser(
context=context, entity_id=entity_id, permission=(POLICY_EDIT,)
)
raise UnknownUser(context=context, entity_id=entity_id, permission=POLICY_EDIT)
if not user.permissions.check_entity(entity_id, POLICY_EDIT):
raise Unauthorized(
context=context, entity_id=entity_id, permission=(POLICY_EDIT,)
)
raise Unauthorized(context=context, entity_id=entity_id, permission=POLICY_EDIT)
async def async_setup(hass: HomeAssistantType, config: Dict) -> bool:

View File

@ -79,6 +79,7 @@ from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM, UnitS
# Typing imports that create a circular dependency
if TYPE_CHECKING:
from homeassistant.auth import AuthManager
from homeassistant.config_entries import ConfigEntries
from homeassistant.components.http import HomeAssistantHTTP
@ -174,6 +175,7 @@ class CoreState(enum.Enum):
class HomeAssistant:
"""Root object of the Home Assistant home automation."""
auth: "AuthManager"
http: "HomeAssistantHTTP" = None # type: ignore
config_entries: "ConfigEntries" = None # type: ignore

View File

@ -1,5 +1,5 @@
"""The exceptions used by Home Assistant."""
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional
import jinja2
@ -49,7 +49,7 @@ class Unauthorized(HomeAssistantError):
entity_id: Optional[str] = None,
config_entry_id: Optional[str] = None,
perm_category: Optional[str] = None,
permission: Optional[Tuple[str]] = None,
permission: Optional[str] = None,
) -> None:
"""Unauthorized error."""
super().__init__(self.__class__.__name__)

View File

@ -5,7 +5,7 @@ from datetime import datetime, timedelta
import functools as ft
import logging
from timeit import default_timer as timer
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Awaitable, Dict, Iterable, List, Optional, Union
from homeassistant.config import DATA_CUSTOMIZE
from homeassistant.const import (
@ -32,11 +32,10 @@ from homeassistant.helpers.entity_registry import (
EVENT_ENTITY_REGISTRY_UPDATED,
RegistryEntry,
)
from homeassistant.helpers.event import Event
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
from homeassistant.util.async_ import run_callback_threadsafe
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10
@ -258,7 +257,7 @@ class Entity(ABC):
self._context = context
self._context_set = dt_util.utcnow()
async def async_update_ha_state(self, force_refresh=False):
async def async_update_ha_state(self, force_refresh: bool = False) -> None:
"""Update Home Assistant with current state of entity.
If force_refresh == True will update entity before setting state.
@ -294,14 +293,15 @@ class Entity(ABC):
f"No entity id specified for entity {self.name}"
)
self._async_write_ha_state() # type: ignore
self._async_write_ha_state()
@callback
def _async_write_ha_state(self):
def _async_write_ha_state(self) -> None:
"""Write the state to the state machine."""
if self.registry_entry and self.registry_entry.disabled_by:
if not self._disabled_reported:
self._disabled_reported = True
assert self.platform is not None
_LOGGER.warning(
"Entity %s is incorrectly being triggered for updates while it is disabled. This is a bug in the %s integration.",
self.entity_id,
@ -317,9 +317,8 @@ class Entity(ABC):
if not self.available:
state = STATE_UNAVAILABLE
else:
state = self.state
state = STATE_UNKNOWN if state is None else str(state)
sstate = self.state
state = STATE_UNKNOWN if sstate is None else str(sstate)
attr.update(self.state_attributes or {})
attr.update(self.device_state_attributes or {})
@ -383,6 +382,7 @@ class Entity(ABC):
)
# Overwrite properties that have been set in the config file.
assert self.hass is not None
if DATA_CUSTOMIZE in self.hass.data:
attr.update(self.hass.data[DATA_CUSTOMIZE].get(self.entity_id))
@ -403,7 +403,7 @@ class Entity(ABC):
pass
if (
self._context is not None
self._context_set is not None
and dt_util.utcnow() - self._context_set > self.context_recent_time
):
self._context = None
@ -413,7 +413,7 @@ class Entity(ABC):
self.entity_id, state, attr, self.force_update, self._context
)
def schedule_update_ha_state(self, force_refresh=False):
def schedule_update_ha_state(self, force_refresh: bool = False) -> None:
"""Schedule an update ha state change task.
Scheduling the update avoids executor deadlocks.
@ -423,10 +423,11 @@ class Entity(ABC):
If state is changed more than once before the ha state change task has
been executed, the intermediate state transitions will be missed.
"""
self.hass.add_job(self.async_update_ha_state(force_refresh))
assert self.hass is not None
self.hass.add_job(self.async_update_ha_state(force_refresh)) # type: ignore
@callback
def async_schedule_update_ha_state(self, force_refresh=False):
def async_schedule_update_ha_state(self, force_refresh: bool = False) -> None:
"""Schedule an update ha state change task.
This method must be run in the event loop.
@ -438,11 +439,12 @@ class Entity(ABC):
been executed, the intermediate state transitions will be missed.
"""
if force_refresh:
assert self.hass is not None
self.hass.async_create_task(self.async_update_ha_state(force_refresh))
else:
self.async_write_ha_state()
async def async_device_update(self, warning=True):
async def async_device_update(self, warning: bool = True) -> None:
"""Process 'update' or 'async_update' from entity.
This method is a coroutine.
@ -455,6 +457,7 @@ class Entity(ABC):
if self.parallel_updates:
await self.parallel_updates.acquire()
assert self.hass is not None
if warning:
update_warn = self.hass.loop.call_later(
SLOW_UPDATE_WARNING,
@ -467,9 +470,11 @@ class Entity(ABC):
try:
# pylint: disable=no-member
if hasattr(self, "async_update"):
await self.async_update()
await self.async_update() # type: ignore
elif hasattr(self, "update"):
await self.hass.async_add_executor_job(self.update)
await self.hass.async_add_executor_job(
self.update # type: ignore
)
finally:
self._update_staged = False
if warning:
@ -534,7 +539,7 @@ class Entity(ABC):
Not to be extended by integrations.
"""
async def _async_registry_updated(self, event):
async def _async_registry_updated(self, event: Event) -> None:
"""Handle entity registry update."""
data = event.data
if data["action"] == "remove" and data["entity_id"] == self.entity_id:
@ -547,24 +552,28 @@ class Entity(ABC):
):
return
assert self.hass is not None
ent_reg = await self.hass.helpers.entity_registry.async_get_registry()
old = self.registry_entry
self.registry_entry = ent_reg.async_get(data["entity_id"])
assert self.registry_entry is not None
if self.registry_entry.disabled_by is not None:
await self.async_remove()
return
assert old is not None
if self.registry_entry.entity_id == old.entity_id:
self.async_write_ha_state()
return
await self.async_remove()
assert self.platform is not None
self.entity_id = self.registry_entry.entity_id
await self.platform.async_add_entities([self])
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""Return the comparison."""
if not isinstance(other, self.__class__):
return False
@ -587,8 +596,7 @@ class Entity(ABC):
"""Return the representation."""
return f"<Entity {self.name}: {self.state}>"
# call an requests
async def async_request_call(self, coro):
async def async_request_call(self, coro: Awaitable) -> None:
"""Process request batched."""
if self.parallel_updates:
await self.parallel_updates.acquire()
@ -617,16 +625,18 @@ class ToggleEntity(Entity):
"""Turn the entity on."""
raise NotImplementedError()
async def async_turn_on(self, **kwargs):
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn the entity on."""
assert self.hass is not None
await self.hass.async_add_executor_job(ft.partial(self.turn_on, **kwargs))
def turn_off(self, **kwargs: Any) -> None:
"""Turn the entity off."""
raise NotImplementedError()
async def async_turn_off(self, **kwargs):
async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn the entity off."""
assert self.hass is not None
await self.hass.async_add_executor_job(ft.partial(self.turn_off, **kwargs))
def toggle(self, **kwargs: Any) -> None:
@ -636,7 +646,7 @@ class ToggleEntity(Entity):
else:
self.turn_on(**kwargs)
async def async_toggle(self, **kwargs):
async def async_toggle(self, **kwargs: Any) -> None:
"""Toggle the entity."""
if self.is_on:
await self.async_turn_off(**kwargs)

View File

@ -542,7 +542,7 @@ class EntityPlatform:
for entity in self.entities.values():
if not entity.should_poll:
continue
tasks.append(entity.async_update_ha_state(True)) # type: ignore
tasks.append(entity.async_update_ha_state(True))
if tasks:
await asyncio.gather(*tasks)

View File

@ -505,7 +505,7 @@ def async_register_admin_service(
"""Register a service that requires admin access."""
@wraps(service_func)
async def admin_handler(call):
async def admin_handler(call: ha.ServiceCall) -> None:
if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id)
if user is None: