Add type hints to template states (#82582)

* Add type hints to template states

* Undo rename

* Remove invalid mypy issue link
pull/77491/head
epenet 2022-11-23 17:46:51 +01:00 committed by GitHub
parent 95cbf7cca7
commit aa02a53ac6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 19 additions and 16 deletions

View File

@ -40,6 +40,7 @@ from homeassistant.const import (
STATE_UNKNOWN,
)
from homeassistant.core import (
Context,
HomeAssistant,
State,
callback,
@ -683,7 +684,7 @@ class AllStates:
if render_info is not None:
render_info.all_states_lifecycle = True
def __iter__(self):
def __iter__(self) -> Generator[TemplateState, None, None]:
"""Return all states."""
self._collect_all()
return _state_generator(self._hass, None)
@ -693,7 +694,7 @@ class AllStates:
self._collect_all_lifecycle()
return self._hass.states.async_entity_ids_count()
def __call__(self, entity_id):
def __call__(self, entity_id: str) -> str:
"""Return the states."""
state = _get_state(self._hass, entity_id)
return STATE_UNKNOWN if state is None else state.state
@ -716,7 +717,7 @@ class DomainStates:
self._hass = hass
self._domain = domain
def __getattr__(self, name):
def __getattr__(self, name: str) -> TemplateState | None:
"""Return the states."""
return _get_state_if_valid(self._hass, f"{self._domain}.{name}")
@ -734,7 +735,7 @@ class DomainStates:
if entity_collect is not None:
entity_collect.domains_lifecycle.add(self._domain)
def __iter__(self):
def __iter__(self) -> Generator[TemplateState, None, None]:
"""Return the iteration over all the states."""
self._collect_domain()
return _state_generator(self._hass, self._domain)
@ -774,7 +775,7 @@ class TemplateStateBase(State):
# Jinja will try __getitem__ first and it avoids the need
# to call is_safe_attribute
def __getitem__(self, item):
def __getitem__(self, item: str) -> Any:
"""Return a property as an attribute for jinja."""
if item in _COLLECTABLE_STATE_ATTRIBUTES:
# _collect_state inlined here for performance
@ -788,7 +789,7 @@ class TemplateStateBase(State):
raise KeyError
@property
def entity_id(self):
def entity_id(self) -> str: # type: ignore[override]
"""Wrap State.entity_id.
Intentionally does not collect state
@ -796,49 +797,49 @@ class TemplateStateBase(State):
return self._entity_id
@property
def state(self):
def state(self) -> str: # type: ignore[override]
"""Wrap State.state."""
self._collect_state()
return self._state.state
@property
def attributes(self):
def attributes(self) -> ReadOnlyDict[str, Any]: # type: ignore[override]
"""Wrap State.attributes."""
self._collect_state()
return self._state.attributes
@property
def last_changed(self):
def last_changed(self) -> datetime: # type: ignore[override]
"""Wrap State.last_changed."""
self._collect_state()
return self._state.last_changed
@property
def last_updated(self):
def last_updated(self) -> datetime: # type: ignore[override]
"""Wrap State.last_updated."""
self._collect_state()
return self._state.last_updated
@property
def context(self):
def context(self) -> Context: # type: ignore[override]
"""Wrap State.context."""
self._collect_state()
return self._state.context
@property
def domain(self):
def domain(self) -> str: # type: ignore[override]
"""Wrap State.domain."""
self._collect_state()
return self._state.domain
@property
def object_id(self):
def object_id(self) -> str: # type: ignore[override]
"""Wrap State.object_id."""
self._collect_state()
return self._state.object_id
@property
def name(self):
def name(self) -> str:
"""Wrap State.name."""
self._collect_state()
return self._state.name
@ -882,7 +883,7 @@ class TemplateStateFromEntityId(TemplateStateBase):
super().__init__(hass, collect, entity_id)
@property
def _state(self) -> State: # type: ignore[override] # mypy issue 4125
def _state(self) -> State: # type: ignore[override]
state = self._hass.states.get(self._entity_id)
if not state:
state = State(self._entity_id, STATE_UNKNOWN)
@ -903,7 +904,9 @@ def _template_state_no_collect(hass: HomeAssistant, state: State) -> TemplateSta
return TemplateState(hass, state, collect=False)
def _state_generator(hass: HomeAssistant, domain: str | None) -> Generator:
def _state_generator(
hass: HomeAssistant, domain: str | None
) -> Generator[TemplateState, None, None]:
"""State generator for a domain or all states."""
for state in sorted(hass.states.async_all(domain), key=attrgetter("entity_id")):
yield _template_state_no_collect(hass, state)