From 65b62d877db60208bdb35f45d8dd5238d8d1cc79 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 3 Jun 2023 09:26:28 -0400 Subject: [PATCH] Keep track what devices support Assist features (#93990) --- .../components/assist_pipeline/pipeline.py | 1 + .../components/assist_pipeline/select.py | 19 ++++++++--- .../assist_pipeline/websocket_api.py | 1 - tests/components/assist_pipeline/conftest.py | 15 +++++++-- .../components/assist_pipeline/test_select.py | 33 +++++++++++++++++-- 5 files changed, 58 insertions(+), 11 deletions(-) diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 031053e8a45..d08e1fc3e50 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -949,6 +949,7 @@ class PipelineData: pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]] pipeline_store: PipelineStorageCollection + pipeline_devices: set[str] = field(default_factory=set, init=False) @dataclass diff --git a/homeassistant/components/assist_pipeline/select.py b/homeassistant/components/assist_pipeline/select.py index 9ac1d6b5888..8e9f11252be 100644 --- a/homeassistant/components/assist_pipeline/select.py +++ b/homeassistant/components/assist_pipeline/select.py @@ -10,7 +10,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import collection, entity_registry as er, restore_state from .const import DOMAIN -from .pipeline import PipelineStorageCollection +from .pipeline import PipelineData, PipelineStorageCollection OPTION_PREFERRED = "preferred" @@ -60,15 +60,24 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity): """When entity is added to Home Assistant.""" await super().async_added_to_hass() - pipeline_store: PipelineStorageCollection = self.hass.data[ - DOMAIN - ].pipeline_store - pipeline_store.async_add_change_set_listener(self._pipelines_updated) + pipeline_data: PipelineData = self.hass.data[DOMAIN] + pipeline_store = pipeline_data.pipeline_store + self.async_on_remove( + pipeline_store.async_add_change_set_listener(self._pipelines_updated) + ) state = await self.async_get_last_state() if state is not None and state.state in self.options: self._attr_current_option = state.state + if self.registry_entry and (device_id := self.registry_entry.device_id): + pipeline_data.pipeline_devices.add(device_id) + self.async_on_remove( + lambda: pipeline_data.pipeline_devices.discard( + device_id # type: ignore[arg-type] + ) + ) + async def async_select_option(self, option: str) -> None: """Select an option.""" self._attr_current_option = option diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index 3d8a07dc0b3..bd2ec53db40 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -280,7 +280,6 @@ def websocket_get_run( ) -@callback @websocket_api.websocket_command( { vol.Required("type"): "assist_pipeline/language/list", diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index 7b0b98d65a3..5aa760cc606 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -9,7 +9,10 @@ import pytest from homeassistant.components import stt, tts from homeassistant.components.assist_pipeline import DOMAIN -from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection +from homeassistant.components.assist_pipeline.pipeline import ( + PipelineData, + PipelineStorageCollection, +) from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback @@ -260,6 +263,12 @@ async def init_components(hass: HomeAssistant, init_supporting_components): @pytest.fixture -def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection: +def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData: + """Return pipeline data.""" + return hass.data[DOMAIN] + + +@pytest.fixture +def pipeline_storage(pipeline_data) -> PipelineStorageCollection: """Return pipeline storage collection.""" - return hass.data[DOMAIN].pipeline_store + return pipeline_data.pipeline_store diff --git a/tests/components/assist_pipeline/test_select.py b/tests/components/assist_pipeline/test_select.py index 30874e7b756..2bc580864d7 100644 --- a/tests/components/assist_pipeline/test_select.py +++ b/tests/components/assist_pipeline/test_select.py @@ -5,10 +5,15 @@ from __future__ import annotations import pytest from homeassistant.components.assist_pipeline import Pipeline -from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection +from homeassistant.components.assist_pipeline.pipeline import ( + PipelineData, + PipelineStorageCollection, +) from homeassistant.components.assist_pipeline.select import AssistPipelineSelect from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr +from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from tests.common import MockConfigEntry, MockPlatform, mock_entity_platform @@ -25,7 +30,11 @@ class SelectPlatform(MockPlatform): async_add_entities: AddEntitiesCallback, ) -> None: """Set up fake select platform.""" - async_add_entities([AssistPipelineSelect(hass, "test")]) + entity = AssistPipelineSelect(hass, "test") + entity._attr_device_info = DeviceInfo( + identifiers={("test", "test")}, + ) + async_add_entities([entity]) @pytest.fixture @@ -33,6 +42,7 @@ async def init_select(hass: HomeAssistant, init_components) -> ConfigEntry: """Initialize select entity.""" mock_entity_platform(hass, "select.assist_pipeline", SelectPlatform()) config_entry = MockConfigEntry(domain="assist_pipeline") + config_entry.add_to_hass(hass) assert await hass.config_entries.async_forward_entry_setup(config_entry, "select") return config_entry @@ -77,6 +87,25 @@ async def pipeline_2( ) +async def test_select_entity_registering_device( + hass: HomeAssistant, + init_select: ConfigEntry, + pipeline_data: PipelineData, +) -> None: + """Test entity registering as an assist device.""" + dev_reg = dr.async_get(hass) + device = dev_reg.async_get_device({("test", "test")}) + + # Test device is registered + assert pipeline_data.pipeline_devices == {device.id} + + await hass.config_entries.async_remove(init_select.entry_id) + await hass.async_block_till_done() + + # Test device is removed + assert pipeline_data.pipeline_devices == set() + + async def test_select_entity_changing_pipelines( hass: HomeAssistant, init_select: ConfigEntry,