Use runtime data in HEOS (#132030)

* Adopt runtime_data

* Fix missing variable assignment

* Address PR feedback
pull/132062/head
Andrew Sayre 2024-12-02 01:19:43 -06:00 committed by GitHub
parent 4eb5734d73
commit 4eb75a56e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 62 additions and 88 deletions

View File

@ -3,10 +3,11 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from datetime import timedelta
import logging
from pyheos import Heos, HeosError, const as heos_const
from pyheos import Heos, HeosError, HeosPlayer, const as heos_const
import voluptuous as vol
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
@ -27,10 +28,6 @@ from .config_flow import format_title
from .const import (
COMMAND_RETRY_ATTEMPTS,
COMMAND_RETRY_DELAY,
DATA_CONTROLLER_MANAGER,
DATA_ENTITY_ID_MAP,
DATA_GROUP_MANAGER,
DATA_SOURCE_MANAGER,
DOMAIN,
SIGNAL_HEOS_PLAYER_ADDED,
SIGNAL_HEOS_UPDATED,
@ -51,6 +48,19 @@ MIN_UPDATE_SOURCES = timedelta(seconds=1)
_LOGGER = logging.getLogger(__name__)
@dataclass
class HeosRuntimeData:
"""Runtime data and coordinators for HEOS config entries."""
controller_manager: ControllerManager
group_manager: GroupManager
source_manager: SourceManager
players: dict[int, HeosPlayer]
type HeosConfigEntry = ConfigEntry[HeosRuntimeData]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the HEOS component."""
if DOMAIN not in config:
@ -75,7 +85,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: HeosConfigEntry) -> bool:
"""Initialize config entry which represents the HEOS controller."""
# For backwards compat
if entry.unique_id is None:
@ -128,17 +138,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
source_manager = SourceManager(favorites, inputs)
source_manager.connect_update(hass, controller)
group_manager = GroupManager(hass, controller)
group_manager = GroupManager(hass, controller, players)
hass.data[DOMAIN] = {
DATA_CONTROLLER_MANAGER: controller_manager,
DATA_GROUP_MANAGER: group_manager,
DATA_SOURCE_MANAGER: source_manager,
Platform.MEDIA_PLAYER: players,
# Maps player_id to entity_id. Populated by the individual
# HeosMediaPlayer entities.
DATA_ENTITY_ID_MAP: {},
}
entry.runtime_data = HeosRuntimeData(
controller_manager, group_manager, source_manager, players
)
services.register(hass, controller)
group_manager.connect_update()
@ -149,11 +153,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: HeosConfigEntry) -> bool:
"""Unload a config entry."""
controller_manager = hass.data[DOMAIN][DATA_CONTROLLER_MANAGER]
await controller_manager.disconnect()
hass.data.pop(DOMAIN)
await entry.runtime_data.controller_manager.disconnect()
services.remove(hass)
@ -246,21 +248,25 @@ class ControllerManager:
class GroupManager:
"""Class that manages HEOS groups."""
def __init__(self, hass, controller):
def __init__(
self, hass: HomeAssistant, controller: Heos, players: dict[int, HeosPlayer]
) -> None:
"""Init group manager."""
self._hass = hass
self._group_membership = {}
self._group_membership: dict[str, str] = {}
self._disconnect_player_added = None
self._initialized = False
self.controller = controller
self.players = players
self.entity_id_map: dict[int, str] = {}
def _get_entity_id_to_player_id_map(self) -> dict:
"""Return mapping of all HeosMediaPlayer entity_ids to player_ids."""
return {v: k for k, v in self._hass.data[DOMAIN][DATA_ENTITY_ID_MAP].items()}
return {v: k for k, v in self.entity_id_map.items()}
async def async_get_group_membership(self):
async def async_get_group_membership(self) -> dict[str, list[str]]:
"""Return all group members for each player as entity_ids."""
group_info_by_entity_id = {
group_info_by_entity_id: dict[str, list[str]] = {
player_entity_id: []
for player_entity_id in self._get_entity_id_to_player_id_map()
}
@ -271,7 +277,7 @@ class GroupManager:
_LOGGER.error("Unable to get HEOS group info: %s", err)
return group_info_by_entity_id
player_id_to_entity_id_map = self._hass.data[DOMAIN][DATA_ENTITY_ID_MAP]
player_id_to_entity_id_map = self.entity_id_map
for group in groups.values():
leader_entity_id = player_id_to_entity_id_map.get(group.leader.player_id)
member_entity_ids = [
@ -282,9 +288,9 @@ class GroupManager:
# Make sure the group leader is always the first element
group_info = [leader_entity_id, *member_entity_ids]
if leader_entity_id:
group_info_by_entity_id[leader_entity_id] = group_info
group_info_by_entity_id[leader_entity_id] = group_info # type: ignore[assignment]
for member_entity_id in member_entity_ids:
group_info_by_entity_id[member_entity_id] = group_info
group_info_by_entity_id[member_entity_id] = group_info # type: ignore[assignment]
return group_info_by_entity_id
@ -358,13 +364,9 @@ class GroupManager:
# When adding a new HEOS player we need to update the groups.
async def _async_handle_player_added():
# Avoid calling async_update_groups when `DATA_ENTITY_ID_MAP` has not been
# Avoid calling async_update_groups when the entity_id map has not been
# fully populated yet. This may only happen during early startup.
if (
len(self._hass.data[DOMAIN][Platform.MEDIA_PLAYER])
<= len(self._hass.data[DOMAIN][DATA_ENTITY_ID_MAP])
and not self._initialized
):
if len(self.players) <= len(self.entity_id_map) and not self._initialized:
self._initialized = True
await self.async_update_groups(SIGNAL_HEOS_PLAYER_ADDED)

View File

@ -4,10 +4,6 @@ ATTR_PASSWORD = "password"
ATTR_USERNAME = "username"
COMMAND_RETRY_ATTEMPTS = 2
COMMAND_RETRY_DELAY = 1
DATA_CONTROLLER_MANAGER = "controller"
DATA_ENTITY_ID_MAP = "entity_id_map"
DATA_GROUP_MANAGER = "group_manager"
DATA_SOURCE_MANAGER = "source_manager"
DATA_DISCOVERED_HOSTS = "heos_discovered_hosts"
DOMAIN = "heos"
SERVICE_SIGN_IN = "sign_in"

View File

@ -13,7 +13,6 @@ from pyheos import HeosError, const as heos_const
from homeassistant.components import media_source
from homeassistant.components.media_player import (
ATTR_MEDIA_ENQUEUE,
DOMAIN as MEDIA_PLAYER_DOMAIN,
BrowseMedia,
MediaPlayerEnqueue,
MediaPlayerEntity,
@ -22,7 +21,6 @@ from homeassistant.components.media_player import (
MediaType,
async_process_play_media_url,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.dispatcher import (
@ -32,14 +30,8 @@ from homeassistant.helpers.dispatcher import (
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.dt import utcnow
from .const import (
DATA_ENTITY_ID_MAP,
DATA_GROUP_MANAGER,
DATA_SOURCE_MANAGER,
DOMAIN as HEOS_DOMAIN,
SIGNAL_HEOS_PLAYER_ADDED,
SIGNAL_HEOS_UPDATED,
)
from . import GroupManager, HeosConfigEntry, SourceManager
from .const import DOMAIN as HEOS_DOMAIN, SIGNAL_HEOS_PLAYER_ADDED, SIGNAL_HEOS_UPDATED
BASE_SUPPORTED_FEATURES = (
MediaPlayerEntityFeature.VOLUME_MUTE
@ -80,11 +72,16 @@ _LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
hass: HomeAssistant, entry: HeosConfigEntry, async_add_entities: AddEntitiesCallback
) -> None:
"""Add media players for a config entry."""
players = hass.data[HEOS_DOMAIN][MEDIA_PLAYER_DOMAIN]
devices = [HeosMediaPlayer(player) for player in players.values()]
players = entry.runtime_data.players
devices = [
HeosMediaPlayer(
player, entry.runtime_data.source_manager, entry.runtime_data.group_manager
)
for player in players.values()
]
async_add_entities(devices, True)
@ -120,13 +117,15 @@ class HeosMediaPlayer(MediaPlayerEntity):
_attr_has_entity_name = True
_attr_name = None
def __init__(self, player):
def __init__(
self, player, source_manager: SourceManager, group_manager: GroupManager
) -> None:
"""Initialize."""
self._media_position_updated_at = None
self._player = player
self._signals = []
self._source_manager = None
self._group_manager = None
self._signals: list = []
self._source_manager = source_manager
self._group_manager = group_manager
self._attr_unique_id = str(player.player_id)
self._attr_device_info = DeviceInfo(
identifiers={(HEOS_DOMAIN, player.player_id)},
@ -161,9 +160,7 @@ class HeosMediaPlayer(MediaPlayerEntity):
async_dispatcher_connect(self.hass, SIGNAL_HEOS_UPDATED, self._heos_updated)
)
# Register this player's entity_id so it can be resolved by the group manager
self.hass.data[HEOS_DOMAIN][DATA_ENTITY_ID_MAP][self._player.player_id] = (
self.entity_id
)
self._group_manager.entity_id_map[self._player.player_id] = self.entity_id
async_dispatcher_send(self.hass, SIGNAL_HEOS_PLAYER_ADDED)
@log_command_error("clear playlist")
@ -294,12 +291,6 @@ class HeosMediaPlayer(MediaPlayerEntity):
ior, current_support, BASE_SUPPORTED_FEATURES
)
if self._group_manager is None:
self._group_manager = self.hass.data[HEOS_DOMAIN][DATA_GROUP_MANAGER]
if self._source_manager is None:
self._source_manager = self.hass.data[HEOS_DOMAIN][DATA_SOURCE_MANAGER]
@log_command_error("unjoin_player")
async def async_unjoin_player(self) -> None:
"""Remove this player from any group."""

View File

@ -8,15 +8,11 @@ import pytest
from homeassistant.components.heos import (
ControllerManager,
HeosRuntimeData,
async_setup_entry,
async_unload_entry,
)
from homeassistant.components.heos.const import (
DATA_CONTROLLER_MANAGER,
DATA_SOURCE_MANAGER,
DOMAIN,
)
from homeassistant.components.media_player import DOMAIN as MEDIA_PLAYER_DOMAIN
from homeassistant.components.heos.const import DOMAIN
from homeassistant.const import CONF_HOST
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
@ -92,10 +88,6 @@ async def test_async_setup_entry_loads_platforms(
assert controller.get_favorites.call_count == 1
assert controller.get_input_sources.call_count == 1
controller.disconnect.assert_not_called()
assert hass.data[DOMAIN][DATA_CONTROLLER_MANAGER].controller == controller
assert hass.data[DOMAIN][MEDIA_PLAYER_DOMAIN] == controller.players
assert hass.data[DOMAIN][DATA_SOURCE_MANAGER].favorites == favorites
assert hass.data[DOMAIN][DATA_SOURCE_MANAGER].inputs == input_sources
async def test_async_setup_entry_not_signed_in_loads_platforms(
@ -121,10 +113,6 @@ async def test_async_setup_entry_not_signed_in_loads_platforms(
assert controller.get_favorites.call_count == 0
assert controller.get_input_sources.call_count == 1
controller.disconnect.assert_not_called()
assert hass.data[DOMAIN][DATA_CONTROLLER_MANAGER].controller == controller
assert hass.data[DOMAIN][MEDIA_PLAYER_DOMAIN] == controller.players
assert hass.data[DOMAIN][DATA_SOURCE_MANAGER].favorites == {}
assert hass.data[DOMAIN][DATA_SOURCE_MANAGER].inputs == input_sources
assert (
"127.0.0.1 is not logged in to a HEOS account and will be unable to retrieve "
"HEOS favorites: Use the 'heos.sign_in' service to sign-in to a HEOS account"
@ -163,7 +151,8 @@ async def test_async_setup_entry_player_failure(
async def test_unload_entry(hass: HomeAssistant, config_entry, controller) -> None:
"""Test entries are unloaded correctly."""
controller_manager = Mock(ControllerManager)
hass.data[DOMAIN] = {DATA_CONTROLLER_MANAGER: controller_manager}
config_entry.runtime_data = HeosRuntimeData(controller_manager, None, None, {})
with patch.object(
hass.config_entries, "async_forward_entry_unload", return_value=True
) as unload:
@ -186,7 +175,7 @@ async def test_update_sources_retry(
assert await async_setup_component(hass, DOMAIN, config)
controller.get_favorites.reset_mock()
controller.get_input_sources.reset_mock()
source_manager = hass.data[DOMAIN][DATA_SOURCE_MANAGER]
source_manager = config_entry.runtime_data.source_manager
source_manager.retry_delay = 0
source_manager.max_retry_attempts = 1
controller.get_favorites.side_effect = CommandFailedError("Test", "test", 0)

View File

@ -8,11 +8,7 @@ from pyheos.error import HeosError
import pytest
from homeassistant.components.heos import media_player
from homeassistant.components.heos.const import (
DATA_SOURCE_MANAGER,
DOMAIN,
SIGNAL_HEOS_UPDATED,
)
from homeassistant.components.heos.const import DOMAIN, SIGNAL_HEOS_UPDATED
from homeassistant.components.media_player import (
ATTR_GROUP_MEMBERS,
ATTR_INPUT_SOURCE,
@ -106,7 +102,7 @@ async def test_state_attributes(
assert ATTR_INPUT_SOURCE not in state.attributes
assert (
state.attributes[ATTR_INPUT_SOURCE_LIST]
== hass.data[DOMAIN][DATA_SOURCE_MANAGER].source_list
== config_entry.runtime_data.source_manager.source_list
)
@ -219,7 +215,7 @@ async def test_updates_from_sources_updated(
const.SIGNAL_CONTROLLER_EVENT, const.EVENT_SOURCES_CHANGED, {}
)
await event.wait()
source_list = hass.data[DOMAIN][DATA_SOURCE_MANAGER].source_list
source_list = config_entry.runtime_data.source_manager.source_list
assert len(source_list) == 2
state = hass.states.get("media_player.test_player")
assert state.attributes[ATTR_INPUT_SOURCE_LIST] == source_list
@ -318,7 +314,7 @@ async def test_updates_from_user_changed(
const.SIGNAL_CONTROLLER_EVENT, const.EVENT_USER_CHANGED, None
)
await event.wait()
source_list = hass.data[DOMAIN][DATA_SOURCE_MANAGER].source_list
source_list = config_entry.runtime_data.source_manager.source_list
assert len(source_list) == 1
state = hass.states.get("media_player.test_player")
assert state.attributes[ATTR_INPUT_SOURCE_LIST] == source_list