Keep track what devices support Assist features (#93990)
parent
faacf1658f
commit
65b62d877d
|
@ -949,6 +949,7 @@ class PipelineData:
|
|||
|
||||
pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]]
|
||||
pipeline_store: PipelineStorageCollection
|
||||
pipeline_devices: set[str] = field(default_factory=set, init=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -10,7 +10,7 @@ from homeassistant.core import HomeAssistant, callback
|
|||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||
|
||||
from .const import DOMAIN
|
||||
from .pipeline import PipelineStorageCollection
|
||||
from .pipeline import PipelineData, PipelineStorageCollection
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
||||
|
@ -60,15 +60,24 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
|||
"""When entity is added to Home Assistant."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
pipeline_store: PipelineStorageCollection = self.hass.data[
|
||||
DOMAIN
|
||||
].pipeline_store
|
||||
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
|
||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
self.async_on_remove(
|
||||
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
|
||||
)
|
||||
|
||||
state = await self.async_get_last_state()
|
||||
if state is not None and state.state in self.options:
|
||||
self._attr_current_option = state.state
|
||||
|
||||
if self.registry_entry and (device_id := self.registry_entry.device_id):
|
||||
pipeline_data.pipeline_devices.add(device_id)
|
||||
self.async_on_remove(
|
||||
lambda: pipeline_data.pipeline_devices.discard(
|
||||
device_id # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
"""Select an option."""
|
||||
self._attr_current_option = option
|
||||
|
|
|
@ -280,7 +280,6 @@ def websocket_get_run(
|
|||
)
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "assist_pipeline/language/list",
|
||||
|
|
|
@ -9,7 +9,10 @@ import pytest
|
|||
|
||||
from homeassistant.components import stt, tts
|
||||
from homeassistant.components.assist_pipeline import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
||||
from homeassistant.components.assist_pipeline.pipeline import (
|
||||
PipelineData,
|
||||
PipelineStorageCollection,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
@ -260,6 +263,12 @@ async def init_components(hass: HomeAssistant, init_supporting_components):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
|
||||
def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData:
|
||||
"""Return pipeline data."""
|
||||
return hass.data[DOMAIN]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline_storage(pipeline_data) -> PipelineStorageCollection:
|
||||
"""Return pipeline storage collection."""
|
||||
return hass.data[DOMAIN].pipeline_store
|
||||
return pipeline_data.pipeline_store
|
||||
|
|
|
@ -5,10 +5,15 @@ from __future__ import annotations
|
|||
import pytest
|
||||
|
||||
from homeassistant.components.assist_pipeline import Pipeline
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
||||
from homeassistant.components.assist_pipeline.pipeline import (
|
||||
PipelineData,
|
||||
PipelineStorageCollection,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.helpers.entity import DeviceInfo
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from tests.common import MockConfigEntry, MockPlatform, mock_entity_platform
|
||||
|
@ -25,7 +30,11 @@ class SelectPlatform(MockPlatform):
|
|||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up fake select platform."""
|
||||
async_add_entities([AssistPipelineSelect(hass, "test")])
|
||||
entity = AssistPipelineSelect(hass, "test")
|
||||
entity._attr_device_info = DeviceInfo(
|
||||
identifiers={("test", "test")},
|
||||
)
|
||||
async_add_entities([entity])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -33,6 +42,7 @@ async def init_select(hass: HomeAssistant, init_components) -> ConfigEntry:
|
|||
"""Initialize select entity."""
|
||||
mock_entity_platform(hass, "select.assist_pipeline", SelectPlatform())
|
||||
config_entry = MockConfigEntry(domain="assist_pipeline")
|
||||
config_entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
|
||||
return config_entry
|
||||
|
||||
|
@ -77,6 +87,25 @@ async def pipeline_2(
|
|||
)
|
||||
|
||||
|
||||
async def test_select_entity_registering_device(
|
||||
hass: HomeAssistant,
|
||||
init_select: ConfigEntry,
|
||||
pipeline_data: PipelineData,
|
||||
) -> None:
|
||||
"""Test entity registering as an assist device."""
|
||||
dev_reg = dr.async_get(hass)
|
||||
device = dev_reg.async_get_device({("test", "test")})
|
||||
|
||||
# Test device is registered
|
||||
assert pipeline_data.pipeline_devices == {device.id}
|
||||
|
||||
await hass.config_entries.async_remove(init_select.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Test device is removed
|
||||
assert pipeline_data.pipeline_devices == set()
|
||||
|
||||
|
||||
async def test_select_entity_changing_pipelines(
|
||||
hass: HomeAssistant,
|
||||
init_select: ConfigEntry,
|
||||
|
|
Loading…
Reference in New Issue