Keep track what devices support Assist features (#93990)

pull/93940/head^2
Paulus Schoutsen 2023-06-03 09:26:28 -04:00 committed by GitHub
parent faacf1658f
commit 65b62d877d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 58 additions and 11 deletions

View File

@ -949,6 +949,7 @@ class PipelineData:
pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]]
pipeline_store: PipelineStorageCollection
pipeline_devices: set[str] = field(default_factory=set, init=False)
@dataclass

View File

@ -10,7 +10,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state
from .const import DOMAIN
from .pipeline import PipelineStorageCollection
from .pipeline import PipelineData, PipelineStorageCollection
OPTION_PREFERRED = "preferred"
@ -60,15 +60,24 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()
pipeline_store: PipelineStorageCollection = self.hass.data[
DOMAIN
].pipeline_store
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
pipeline_data: PipelineData = self.hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
self.async_on_remove(
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
)
state = await self.async_get_last_state()
if state is not None and state.state in self.options:
self._attr_current_option = state.state
if self.registry_entry and (device_id := self.registry_entry.device_id):
pipeline_data.pipeline_devices.add(device_id)
self.async_on_remove(
lambda: pipeline_data.pipeline_devices.discard(
device_id # type: ignore[arg-type]
)
)
async def async_select_option(self, option: str) -> None:
"""Select an option."""
self._attr_current_option = option

View File

@ -280,7 +280,6 @@ def websocket_get_run(
)
@callback
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_pipeline/language/list",

View File

@ -9,7 +9,10 @@ import pytest
from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
from homeassistant.components.assist_pipeline.pipeline import (
PipelineData,
PipelineStorageCollection,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
@ -260,6 +263,12 @@ async def init_components(hass: HomeAssistant, init_supporting_components):
@pytest.fixture
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData:
"""Return pipeline data."""
return hass.data[DOMAIN]
@pytest.fixture
def pipeline_storage(pipeline_data) -> PipelineStorageCollection:
"""Return pipeline storage collection."""
return hass.data[DOMAIN].pipeline_store
return pipeline_data.pipeline_store

View File

@ -5,10 +5,15 @@ from __future__ import annotations
import pytest
from homeassistant.components.assist_pipeline import Pipeline
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
from homeassistant.components.assist_pipeline.pipeline import (
PipelineData,
PipelineStorageCollection,
)
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from tests.common import MockConfigEntry, MockPlatform, mock_entity_platform
@ -25,7 +30,11 @@ class SelectPlatform(MockPlatform):
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up fake select platform."""
async_add_entities([AssistPipelineSelect(hass, "test")])
entity = AssistPipelineSelect(hass, "test")
entity._attr_device_info = DeviceInfo(
identifiers={("test", "test")},
)
async_add_entities([entity])
@pytest.fixture
@ -33,6 +42,7 @@ async def init_select(hass: HomeAssistant, init_components) -> ConfigEntry:
"""Initialize select entity."""
mock_entity_platform(hass, "select.assist_pipeline", SelectPlatform())
config_entry = MockConfigEntry(domain="assist_pipeline")
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
return config_entry
@ -77,6 +87,25 @@ async def pipeline_2(
)
async def test_select_entity_registering_device(
hass: HomeAssistant,
init_select: ConfigEntry,
pipeline_data: PipelineData,
) -> None:
"""Test entity registering as an assist device."""
dev_reg = dr.async_get(hass)
device = dev_reg.async_get_device({("test", "test")})
# Test device is registered
assert pipeline_data.pipeline_devices == {device.id}
await hass.config_entries.async_remove(init_select.entry_id)
await hass.async_block_till_done()
# Test device is removed
assert pipeline_data.pipeline_devices == set()
async def test_select_entity_changing_pipelines(
hass: HomeAssistant,
init_select: ConfigEntry,