Keep track what devices support Assist features (#93990)
parent
faacf1658f
commit
65b62d877d
|
@ -949,6 +949,7 @@ class PipelineData:
|
||||||
|
|
||||||
pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]]
|
pipeline_runs: dict[str, LimitedSizeDict[str, PipelineRunDebug]]
|
||||||
pipeline_store: PipelineStorageCollection
|
pipeline_store: PipelineStorageCollection
|
||||||
|
pipeline_devices: set[str] = field(default_factory=set, init=False)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -10,7 +10,7 @@ from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .pipeline import PipelineStorageCollection
|
from .pipeline import PipelineData, PipelineStorageCollection
|
||||||
|
|
||||||
OPTION_PREFERRED = "preferred"
|
OPTION_PREFERRED = "preferred"
|
||||||
|
|
||||||
|
@ -60,15 +60,24 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
||||||
"""When entity is added to Home Assistant."""
|
"""When entity is added to Home Assistant."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
|
|
||||||
pipeline_store: PipelineStorageCollection = self.hass.data[
|
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||||
DOMAIN
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
].pipeline_store
|
self.async_on_remove(
|
||||||
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
|
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
|
||||||
|
)
|
||||||
|
|
||||||
state = await self.async_get_last_state()
|
state = await self.async_get_last_state()
|
||||||
if state is not None and state.state in self.options:
|
if state is not None and state.state in self.options:
|
||||||
self._attr_current_option = state.state
|
self._attr_current_option = state.state
|
||||||
|
|
||||||
|
if self.registry_entry and (device_id := self.registry_entry.device_id):
|
||||||
|
pipeline_data.pipeline_devices.add(device_id)
|
||||||
|
self.async_on_remove(
|
||||||
|
lambda: pipeline_data.pipeline_devices.discard(
|
||||||
|
device_id # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def async_select_option(self, option: str) -> None:
|
async def async_select_option(self, option: str) -> None:
|
||||||
"""Select an option."""
|
"""Select an option."""
|
||||||
self._attr_current_option = option
|
self._attr_current_option = option
|
||||||
|
|
|
@ -280,7 +280,6 @@ def websocket_get_run(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
{
|
{
|
||||||
vol.Required("type"): "assist_pipeline/language/list",
|
vol.Required("type"): "assist_pipeline/language/list",
|
||||||
|
|
|
@ -9,7 +9,10 @@ import pytest
|
||||||
|
|
||||||
from homeassistant.components import stt, tts
|
from homeassistant.components import stt, tts
|
||||||
from homeassistant.components.assist_pipeline import DOMAIN
|
from homeassistant.components.assist_pipeline import DOMAIN
|
||||||
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
from homeassistant.components.assist_pipeline.pipeline import (
|
||||||
|
PipelineData,
|
||||||
|
PipelineStorageCollection,
|
||||||
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
@ -260,6 +263,12 @@ async def init_components(hass: HomeAssistant, init_supporting_components):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
|
def pipeline_data(hass: HomeAssistant, init_components) -> PipelineData:
|
||||||
|
"""Return pipeline data."""
|
||||||
|
return hass.data[DOMAIN]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline_storage(pipeline_data) -> PipelineStorageCollection:
|
||||||
"""Return pipeline storage collection."""
|
"""Return pipeline storage collection."""
|
||||||
return hass.data[DOMAIN].pipeline_store
|
return pipeline_data.pipeline_store
|
||||||
|
|
|
@ -5,10 +5,15 @@ from __future__ import annotations
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline import Pipeline
|
from homeassistant.components.assist_pipeline import Pipeline
|
||||||
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
from homeassistant.components.assist_pipeline.pipeline import (
|
||||||
|
PipelineData,
|
||||||
|
PipelineStorageCollection,
|
||||||
|
)
|
||||||
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
|
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import device_registry as dr
|
||||||
|
from homeassistant.helpers.entity import DeviceInfo
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
from tests.common import MockConfigEntry, MockPlatform, mock_entity_platform
|
from tests.common import MockConfigEntry, MockPlatform, mock_entity_platform
|
||||||
|
@ -25,7 +30,11 @@ class SelectPlatform(MockPlatform):
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up fake select platform."""
|
"""Set up fake select platform."""
|
||||||
async_add_entities([AssistPipelineSelect(hass, "test")])
|
entity = AssistPipelineSelect(hass, "test")
|
||||||
|
entity._attr_device_info = DeviceInfo(
|
||||||
|
identifiers={("test", "test")},
|
||||||
|
)
|
||||||
|
async_add_entities([entity])
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -33,6 +42,7 @@ async def init_select(hass: HomeAssistant, init_components) -> ConfigEntry:
|
||||||
"""Initialize select entity."""
|
"""Initialize select entity."""
|
||||||
mock_entity_platform(hass, "select.assist_pipeline", SelectPlatform())
|
mock_entity_platform(hass, "select.assist_pipeline", SelectPlatform())
|
||||||
config_entry = MockConfigEntry(domain="assist_pipeline")
|
config_entry = MockConfigEntry(domain="assist_pipeline")
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
|
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
|
||||||
return config_entry
|
return config_entry
|
||||||
|
|
||||||
|
@ -77,6 +87,25 @@ async def pipeline_2(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_select_entity_registering_device(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_select: ConfigEntry,
|
||||||
|
pipeline_data: PipelineData,
|
||||||
|
) -> None:
|
||||||
|
"""Test entity registering as an assist device."""
|
||||||
|
dev_reg = dr.async_get(hass)
|
||||||
|
device = dev_reg.async_get_device({("test", "test")})
|
||||||
|
|
||||||
|
# Test device is registered
|
||||||
|
assert pipeline_data.pipeline_devices == {device.id}
|
||||||
|
|
||||||
|
await hass.config_entries.async_remove(init_select.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Test device is removed
|
||||||
|
assert pipeline_data.pipeline_devices == set()
|
||||||
|
|
||||||
|
|
||||||
async def test_select_entity_changing_pipelines(
|
async def test_select_entity_changing_pipelines(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
init_select: ConfigEntry,
|
init_select: ConfigEntry,
|
||||||
|
|
Loading…
Reference in New Issue