Keep track what devices support Assist features (#93990)

pull/93940/head^2
Paulus Schoutsen 2023-06-03 09:26:28 -04:00 committed by GitHub
parent faacf1658f
commit 65b62d877d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 58 additions and 11 deletions

View File

@ -949,6 +949,7 @@ class PipelineData:
pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]] pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]]
pipeline_store: PipelineStorageCollection pipeline_store: PipelineStorageCollection
pipeline_devices: set[str] = field(default_factory=set, init=False)
@dataclass @dataclass

View File

@ -10,7 +10,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state from homeassistant.helpers import collection, entity_registry as er, restore_state
from .const import DOMAIN from .const import DOMAIN
from .pipeline import PipelineStorageCollection from .pipeline import PipelineData, PipelineStorageCollection
OPTION_PREFERRED = "preferred" OPTION_PREFERRED = "preferred"
@ -60,15 +60,24 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
"""When entity is added to Home Assistant.""" """When entity is added to Home Assistant."""
await super().async_added_to_hass() await super().async_added_to_hass()
pipeline_store: PipelineStorageCollection = self.hass.data[ pipeline_data: PipelineData = self.hass.data[DOMAIN]
DOMAIN pipeline_store = pipeline_data.pipeline_store
].pipeline_store self.async_on_remove(
pipeline_store.async_add_change_set_listener(self._pipelines_updated) pipeline_store.async_add_change_set_listener(self._pipelines_updated)
)
state = await self.async_get_last_state() state = await self.async_get_last_state()
if state is not None and state.state in self.options: if state is not None and state.state in self.options:
self._attr_current_option = state.state self._attr_current_option = state.state
if self.registry_entry and (device_id := self.registry_entry.device_id):
pipeline_data.pipeline_devices.add(device_id)
self.async_on_remove(
lambda: pipeline_data.pipeline_devices.discard(
device_id # type: ignore[arg-type]
)
)
async def async_select_option(self, option: str) -> None: async def async_select_option(self, option: str) -> None:
"""Select an option.""" """Select an option."""
self._attr_current_option = option self._attr_current_option = option

View File

@ -280,7 +280,6 @@ def websocket_get_run(
) )
@callback
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "assist_pipeline/language/list", vol.Required("type"): "assist_pipeline/language/list",

View File

@ -9,7 +9,10 @@ import pytest
from homeassistant.components import stt, tts from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import DOMAIN from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection from homeassistant.components.assist_pipeline.pipeline import (
PipelineData,
PipelineStorageCollection,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -260,6 +263,12 @@ async def init_components(hass: HomeAssistant, init_supporting_components):
@pytest.fixture @pytest.fixture
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection: def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData:
"""Return pipeline data."""
return hass.data[DOMAIN]
@pytest.fixture
def pipeline_storage(pipeline_data) -> PipelineStorageCollection:
"""Return pipeline storage collection.""" """Return pipeline storage collection."""
return hass.data[DOMAIN].pipeline_store return pipeline_data.pipeline_store

View File

@ -5,10 +5,15 @@ from __future__ import annotations
import pytest import pytest
from homeassistant.components.assist_pipeline import Pipeline from homeassistant.components.assist_pipeline import Pipeline
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection from homeassistant.components.assist_pipeline.pipeline import (
PipelineData,
PipelineStorageCollection,
)
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from tests.common import MockConfigEntry, MockPlatform, mock_entity_platform from tests.common import MockConfigEntry, MockPlatform, mock_entity_platform
@ -25,7 +30,11 @@ class SelectPlatform(MockPlatform):
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up fake select platform.""" """Set up fake select platform."""
async_add_entities([AssistPipelineSelect(hass, "test")]) entity = AssistPipelineSelect(hass, "test")
entity._attr_device_info = DeviceInfo(
identifiers={("test", "test")},
)
async_add_entities([entity])
@pytest.fixture @pytest.fixture
@ -33,6 +42,7 @@ async def init_select(hass: HomeAssistant, init_components) -> ConfigEntry:
"""Initialize select entity.""" """Initialize select entity."""
mock_entity_platform(hass, "select.assist_pipeline", SelectPlatform()) mock_entity_platform(hass, "select.assist_pipeline", SelectPlatform())
config_entry = MockConfigEntry(domain="assist_pipeline") config_entry = MockConfigEntry(domain="assist_pipeline")
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select") assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
return config_entry return config_entry
@ -77,6 +87,25 @@ async def pipeline_2(
) )
async def test_select_entity_registering_device(
hass: HomeAssistant,
init_select: ConfigEntry,
pipeline_data: PipelineData,
) -> None:
"""Test entity registering as an assist device."""
dev_reg = dr.async_get(hass)
device = dev_reg.async_get_device({("test", "test")})
# Test device is registered
assert pipeline_data.pipeline_devices == {device.id}
await hass.config_entries.async_remove(init_select.entry_id)
await hass.async_block_till_done()
# Test device is removed
assert pipeline_data.pipeline_devices == set()
async def test_select_entity_changing_pipelines( async def test_select_entity_changing_pipelines(
hass: HomeAssistant, hass: HomeAssistant,
init_select: ConfigEntry, init_select: ConfigEntry,