Simplify groups (#63477)

* Simplify group

* Rename async_update to async_update_group_state and mark it as callback

* Simplify _async_start
pull/63586/head
Erik Montnemery 2022-01-07 08:58:45 +01:00 committed by GitHub
parent e222e1b6f0
commit 8bf8709d99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 45 additions and 66 deletions

View File

@ -21,19 +21,13 @@ from homeassistant.const import (
CONF_NAME,
ENTITY_MATCH_ALL,
ENTITY_MATCH_NONE,
EVENT_HOMEASSISTANT_START,
SERVICE_RELOAD,
STATE_OFF,
STATE_ON,
Platform,
)
from homeassistant.core import (
CoreState,
HomeAssistant,
ServiceCall,
callback,
split_entity_id,
)
from homeassistant.core import HomeAssistant, ServiceCall, callback, split_entity_id
from homeassistant.helpers import start
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity, async_generate_entity_id
from homeassistant.helpers.entity_component import EntityComponent
@ -407,21 +401,22 @@ class GroupEntity(Entity):
"""Register listeners."""
async def _update_at_start(_):
await self.async_update()
self.async_update_group_state()
self.async_write_ha_state()
self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _update_at_start)
start.async_at_start(self.hass, _update_at_start)
async def async_defer_or_update_ha_state(self) -> None:
@callback
def async_defer_or_update_ha_state(self) -> None:
"""Only update once at start."""
if self.hass.state != CoreState.running:
if not self.hass.is_running:
return
await self.async_update()
self.async_update_group_state()
self.async_write_ha_state()
@abstractmethod
async def async_update(self) -> None:
def async_update_group_state(self) -> None:
"""Abstract method to update the entity."""
@ -636,22 +631,15 @@ class Group(Entity):
self._async_unsub_state_changed()
self._async_unsub_state_changed = None
async def async_update(self):
@callback
def async_update_group_state(self):
"""Query all members and determine current group state."""
self._state = None
self._async_update_group_state()
async def async_added_to_hass(self):
"""Handle addition to Home Assistant."""
if self.hass.state != CoreState.running:
self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, self._async_start
)
return
if self.tracking:
self._reset_tracked_state()
self._async_start_tracking()
start.async_at_start(self.hass, self._async_start)
async def async_will_remove_from_hass(self):
"""Handle removal from Home Assistant."""

View File

@ -20,7 +20,7 @@ from homeassistant.const import (
STATE_ON,
STATE_UNAVAILABLE,
)
from homeassistant.core import CoreState, Event, HomeAssistant
from homeassistant.core import Event, HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event
@ -90,10 +90,11 @@ class BinarySensorGroup(GroupEntity, BinarySensorEntity):
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
async def async_state_changed_listener(event: Event) -> None:
@callback
def async_state_changed_listener(event: Event) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
await self.async_defer_or_update_ha_state()
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
@ -101,13 +102,10 @@ class BinarySensorGroup(GroupEntity, BinarySensorEntity):
)
)
if self.hass.state == CoreState.running:
await self.async_update()
return
await super().async_added_to_hass()
async def async_update(self) -> None:
@callback
def async_update_group_state(self) -> None:
"""Query all members and determine the binary sensor group state."""
all_states = [self.hass.states.get(x) for x in self._entity_ids]
filtered_states: list[str] = [x.state for x in all_states if x is not None]
@ -120,7 +118,6 @@ class BinarySensorGroup(GroupEntity, BinarySensorEntity):
states = list(map(lambda x: x == STATE_ON, filtered_states))
state = self.mode(states)
self._attr_is_on = state
self.async_write_ha_state()
@property
def device_class(self) -> str | None:

View File

@ -42,7 +42,7 @@ from homeassistant.const import (
STATE_OPEN,
STATE_OPENING,
)
from homeassistant.core import CoreState, Event, HomeAssistant, State
from homeassistant.core import Event, HomeAssistant, State, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event
@ -110,14 +110,14 @@ class CoverGroup(GroupEntity, CoverEntity):
self._attr_extra_state_attributes = {ATTR_ENTITY_ID: entities}
self._attr_unique_id = unique_id
async def _update_supported_features_event(self, event: Event) -> None:
@callback
def _update_supported_features_event(self, event: Event) -> None:
self.async_set_context(event.context)
if (entity := event.data.get("entity_id")) is not None:
await self.async_update_supported_features(
entity, event.data.get("new_state")
)
self.async_update_supported_features(entity, event.data.get("new_state"))
async def async_update_supported_features(
@callback
def async_update_supported_features(
self,
entity_id: str,
new_state: State | None,
@ -130,7 +130,7 @@ class CoverGroup(GroupEntity, CoverEntity):
for values in self._tilts.values():
values.discard(entity_id)
if update_state:
await self.async_defer_or_update_ha_state()
self.async_defer_or_update_ha_state()
return
features = new_state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
@ -162,14 +162,14 @@ class CoverGroup(GroupEntity, CoverEntity):
self._tilts[KEY_POSITION].discard(entity_id)
if update_state:
await self.async_defer_or_update_ha_state()
self.async_defer_or_update_ha_state()
async def async_added_to_hass(self) -> None:
"""Register listeners."""
for entity_id in self._entities:
if (new_state := self.hass.states.get(entity_id)) is None:
continue
await self.async_update_supported_features(
self.async_update_supported_features(
entity_id, new_state, update_state=False
)
self.async_on_remove(
@ -178,9 +178,6 @@ class CoverGroup(GroupEntity, CoverEntity):
)
)
if self.hass.state == CoreState.running:
await self.async_update()
return
await super().async_added_to_hass()
async def async_open_cover(self, **kwargs: Any) -> None:
@ -253,7 +250,8 @@ class CoverGroup(GroupEntity, CoverEntity):
context=self._context,
)
async def async_update(self) -> None:
@callback
def async_update_group_state(self) -> None:
"""Update state and attributes."""
self._attr_assumed_state = False

View File

@ -34,7 +34,7 @@ from homeassistant.const import (
CONF_UNIQUE_ID,
STATE_ON,
)
from homeassistant.core import CoreState, Event, HomeAssistant, State
from homeassistant.core import Event, HomeAssistant, State, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event
@ -125,14 +125,14 @@ class FanGroup(GroupEntity, FanEntity):
"""Return whether or not the fan is currently oscillating."""
return self._oscillating
async def _update_supported_features_event(self, event: Event) -> None:
@callback
def _update_supported_features_event(self, event: Event) -> None:
self.async_set_context(event.context)
if (entity := event.data.get("entity_id")) is not None:
await self.async_update_supported_features(
entity, event.data.get("new_state")
)
self.async_update_supported_features(entity, event.data.get("new_state"))
async def async_update_supported_features(
@callback
def async_update_supported_features(
self,
entity_id: str,
new_state: State | None,
@ -151,14 +151,14 @@ class FanGroup(GroupEntity, FanEntity):
self._fans[feature].discard(entity_id)
if update_state:
await self.async_defer_or_update_ha_state()
self.async_defer_or_update_ha_state()
async def async_added_to_hass(self) -> None:
"""Register listeners."""
for entity_id in self._entities:
if (new_state := self.hass.states.get(entity_id)) is None:
continue
await self.async_update_supported_features(
self.async_update_supported_features(
entity_id, new_state, update_state=False
)
self.async_on_remove(
@ -167,9 +167,6 @@ class FanGroup(GroupEntity, FanEntity):
)
)
if self.hass.state == CoreState.running:
await self.async_update()
return
await super().async_added_to_hass()
async def async_set_percentage(self, percentage: int) -> None:
@ -244,7 +241,8 @@ class FanGroup(GroupEntity, FanEntity):
setattr(self, attr, most_frequent_attribute(states, entity_attr))
self._attr_assumed_state |= not attribute_equal(states, entity_attr)
async def async_update(self) -> None:
@callback
def async_update_group_state(self) -> None:
"""Update state and attributes."""
self._attr_assumed_state = False

View File

@ -47,7 +47,7 @@ from homeassistant.const import (
STATE_ON,
STATE_UNAVAILABLE,
)
from homeassistant.core import CoreState, Event, HomeAssistant, State
from homeassistant.core import Event, HomeAssistant, State, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event
@ -129,10 +129,11 @@ class LightGroup(GroupEntity, LightEntity):
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
async def async_state_changed_listener(event: Event) -> None:
@callback
def async_state_changed_listener(event: Event) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
await self.async_defer_or_update_ha_state()
self.async_defer_or_update_ha_state()
self.async_on_remove(
async_track_state_change_event(
@ -140,10 +141,6 @@ class LightGroup(GroupEntity, LightEntity):
)
)
if self.hass.state == CoreState.running:
await self.async_update()
return
await super().async_added_to_hass()
@property
@ -183,7 +180,8 @@ class LightGroup(GroupEntity, LightEntity):
context=self._context,
)
async def async_update(self) -> None:
@callback
def async_update_group_state(self) -> None:
"""Query all members and determine the light group state."""
all_states = [self.hass.states.get(x) for x in self._entity_ids]
states: list[State] = list(filter(None, all_states))