Add start_conversation service to Assist Satellite (#134921)
* Add start_conversation service to Assist Satellite * Fix tests * Implement start_conversation in voip * Update homeassistant/components/assist_satellite/entity.py --------- Co-authored-by: Michael Hansen <mike@rhasspy.org>pull/137390/head
parent
9c8d31a3d5
commit
f391438d0a
|
@ -1122,6 +1122,7 @@ class PipelineRun:
|
||||||
context=user_input.context,
|
context=user_input.context,
|
||||||
language=user_input.language,
|
language=user_input.language,
|
||||||
agent_id=user_input.agent_id,
|
agent_id=user_input.agent_id,
|
||||||
|
extra_system_prompt=user_input.extra_system_prompt,
|
||||||
)
|
)
|
||||||
speech = conversation_result.response.speech.get("plain", {}).get(
|
speech = conversation_result.response.speech.get("plain", {}).get(
|
||||||
"speech", ""
|
"speech", ""
|
||||||
|
|
|
@ -63,6 +63,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"async_internal_announce",
|
"async_internal_announce",
|
||||||
[AssistSatelliteEntityFeature.ANNOUNCE],
|
[AssistSatelliteEntityFeature.ANNOUNCE],
|
||||||
)
|
)
|
||||||
|
component.async_register_entity_service(
|
||||||
|
"start_conversation",
|
||||||
|
vol.All(
|
||||||
|
cv.make_entity_service_schema(
|
||||||
|
{
|
||||||
|
vol.Optional("start_message"): str,
|
||||||
|
vol.Optional("start_media_id"): str,
|
||||||
|
vol.Optional("extra_system_prompt"): str,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
cv.has_at_least_one_key("start_message", "start_media_id"),
|
||||||
|
),
|
||||||
|
"async_internal_start_conversation",
|
||||||
|
[AssistSatelliteEntityFeature.START_CONVERSATION],
|
||||||
|
)
|
||||||
hass.data[CONNECTION_TEST_DATA] = {}
|
hass.data[CONNECTION_TEST_DATA] = {}
|
||||||
async_register_websocket_api(hass)
|
async_register_websocket_api(hass)
|
||||||
hass.http.register_view(ConnectionTestView())
|
hass.http.register_view(ConnectionTestView())
|
||||||
|
|
|
@ -26,3 +26,6 @@ class AssistSatelliteEntityFeature(IntFlag):
|
||||||
|
|
||||||
ANNOUNCE = 1
|
ANNOUNCE = 1
|
||||||
"""Device supports remotely triggered announcements."""
|
"""Device supports remotely triggered announcements."""
|
||||||
|
|
||||||
|
START_CONVERSATION = 2
|
||||||
|
"""Device supports starting conversations."""
|
||||||
|
|
|
@ -10,7 +10,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Final, Literal, final
|
from typing import Any, Final, Literal, final
|
||||||
|
|
||||||
from homeassistant.components import media_source, stt, tts
|
from homeassistant.components import conversation, media_source, stt, tts
|
||||||
from homeassistant.components.assist_pipeline import (
|
from homeassistant.components.assist_pipeline import (
|
||||||
OPTION_PREFERRED,
|
OPTION_PREFERRED,
|
||||||
AudioSettings,
|
AudioSettings,
|
||||||
|
@ -27,6 +27,7 @@ from homeassistant.components.tts import (
|
||||||
generate_media_source_id as tts_generate_media_source_id,
|
generate_media_source_id as tts_generate_media_source_id,
|
||||||
)
|
)
|
||||||
from homeassistant.core import Context, callback
|
from homeassistant.core import Context, callback
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import entity
|
from homeassistant.helpers import entity
|
||||||
from homeassistant.helpers.entity import EntityDescription
|
from homeassistant.helpers.entity import EntityDescription
|
||||||
|
|
||||||
|
@ -117,6 +118,7 @@ class AssistSatelliteEntity(entity.Entity):
|
||||||
|
|
||||||
_run_has_tts: bool = False
|
_run_has_tts: bool = False
|
||||||
_is_announcing = False
|
_is_announcing = False
|
||||||
|
_extra_system_prompt: str | None = None
|
||||||
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
||||||
_attr_tts_options: dict[str, Any] | None = None
|
_attr_tts_options: dict[str, Any] | None = None
|
||||||
_pipeline_task: asyncio.Task | None = None
|
_pipeline_task: asyncio.Task | None = None
|
||||||
|
@ -216,6 +218,60 @@ class AssistSatelliteEntity(entity.Entity):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def async_internal_start_conversation(
|
||||||
|
self,
|
||||||
|
start_message: str | None = None,
|
||||||
|
start_media_id: str | None = None,
|
||||||
|
extra_system_prompt: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Start a conversation from the satellite.
|
||||||
|
|
||||||
|
If start_media_id is not provided, message is synthesized to
|
||||||
|
audio with the selected pipeline.
|
||||||
|
|
||||||
|
If start_media_id is provided, it is played directly. It is possible
|
||||||
|
to omit the message and the satellite will not show any text.
|
||||||
|
|
||||||
|
Calls async_start_conversation.
|
||||||
|
"""
|
||||||
|
await self._cancel_running_pipeline()
|
||||||
|
|
||||||
|
# The Home Assistant built-in agent doesn't support conversations.
|
||||||
|
pipeline = async_get_pipeline(self.hass, self._resolve_pipeline())
|
||||||
|
if pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
"Built-in conversation agent does not support starting conversations"
|
||||||
|
)
|
||||||
|
|
||||||
|
if start_message is None:
|
||||||
|
start_message = ""
|
||||||
|
|
||||||
|
announcement = await self._resolve_announcement_media_id(
|
||||||
|
start_message, start_media_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._is_announcing:
|
||||||
|
raise SatelliteBusyError
|
||||||
|
|
||||||
|
self._is_announcing = True
|
||||||
|
# Provide our start info to the LLM so it understands context of incoming message
|
||||||
|
if extra_system_prompt is not None:
|
||||||
|
self._extra_system_prompt = extra_system_prompt
|
||||||
|
else:
|
||||||
|
self._extra_system_prompt = start_message or None
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.async_start_conversation(announcement)
|
||||||
|
finally:
|
||||||
|
self._is_announcing = False
|
||||||
|
self._extra_system_prompt = None
|
||||||
|
|
||||||
|
async def async_start_conversation(
|
||||||
|
self, start_announcement: AssistSatelliteAnnouncement
|
||||||
|
) -> None:
|
||||||
|
"""Start a conversation from the satellite."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def async_accept_pipeline_from_satellite(
|
async def async_accept_pipeline_from_satellite(
|
||||||
self,
|
self,
|
||||||
audio_stream: AsyncIterable[bytes],
|
audio_stream: AsyncIterable[bytes],
|
||||||
|
@ -302,6 +358,7 @@ class AssistSatelliteEntity(entity.Entity):
|
||||||
),
|
),
|
||||||
start_stage=start_stage,
|
start_stage=start_stage,
|
||||||
end_stage=end_stage,
|
end_stage=end_stage,
|
||||||
|
conversation_extra_system_prompt=self._extra_system_prompt,
|
||||||
),
|
),
|
||||||
f"{self.entity_id}_pipeline",
|
f"{self.entity_id}_pipeline",
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,6 +7,9 @@
|
||||||
"services": {
|
"services": {
|
||||||
"announce": {
|
"announce": {
|
||||||
"service": "mdi:bullhorn"
|
"service": "mdi:bullhorn"
|
||||||
|
},
|
||||||
|
"start_conversation": {
|
||||||
|
"service": "mdi:forum"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,3 +14,23 @@ announce:
|
||||||
required: false
|
required: false
|
||||||
selector:
|
selector:
|
||||||
text:
|
text:
|
||||||
|
start_conversation:
|
||||||
|
target:
|
||||||
|
entity:
|
||||||
|
domain: assist_satellite
|
||||||
|
supported_features:
|
||||||
|
- assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
|
||||||
|
fields:
|
||||||
|
start_message:
|
||||||
|
required: false
|
||||||
|
example: "You left the lights on in the living room. Turn them off?"
|
||||||
|
selector:
|
||||||
|
text:
|
||||||
|
start_media_id:
|
||||||
|
required: false
|
||||||
|
selector:
|
||||||
|
text:
|
||||||
|
extra_system_prompt:
|
||||||
|
required: false
|
||||||
|
selector:
|
||||||
|
text:
|
||||||
|
|
|
@ -25,6 +25,24 @@
|
||||||
"description": "The media ID to announce instead of using text-to-speech."
|
"description": "The media ID to announce instead of using text-to-speech."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"start_conversation": {
|
||||||
|
"name": "Start Conversation",
|
||||||
|
"description": "Start a conversation from a satellite.",
|
||||||
|
"fields": {
|
||||||
|
"start_message": {
|
||||||
|
"name": "Message",
|
||||||
|
"description": "The message to start with."
|
||||||
|
},
|
||||||
|
"start_media_id": {
|
||||||
|
"name": "Media ID",
|
||||||
|
"description": "The media ID to start with instead of using text-to-speech."
|
||||||
|
},
|
||||||
|
"extra_system_prompt": {
|
||||||
|
"name": "Extra system prompt",
|
||||||
|
"description": "Provide background information to the AI about the request."
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,7 +90,10 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||||
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
|
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
|
||||||
_attr_translation_key = "assist_satellite"
|
_attr_translation_key = "assist_satellite"
|
||||||
_attr_name = None
|
_attr_name = None
|
||||||
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
|
_attr_supported_features = (
|
||||||
|
AssistSatelliteEntityFeature.ANNOUNCE
|
||||||
|
| AssistSatelliteEntityFeature.START_CONVERSATION
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -122,6 +125,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||||
self._check_announcement_ended_task: asyncio.Task | None = None
|
self._check_announcement_ended_task: asyncio.Task | None = None
|
||||||
self._last_chunk_time: float | None = None
|
self._last_chunk_time: float | None = None
|
||||||
self._rtp_port: int | None = None
|
self._rtp_port: int | None = None
|
||||||
|
self._run_pipeline_after_announce: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pipeline_entity_id(self) -> str | None:
|
def pipeline_entity_id(self) -> str | None:
|
||||||
|
@ -172,7 +176,17 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||||
|
|
||||||
Plays announcement in a loop, blocking until the caller hangs up.
|
Plays announcement in a loop, blocking until the caller hangs up.
|
||||||
"""
|
"""
|
||||||
|
await self._do_announce(announcement, run_pipeline_after=False)
|
||||||
|
|
||||||
|
async def _do_announce(
|
||||||
|
self, announcement: AssistSatelliteAnnouncement, run_pipeline_after: bool
|
||||||
|
) -> None:
|
||||||
|
"""Announce media on the satellite.
|
||||||
|
|
||||||
|
Optionally run a voice pipeline after the announcement has finished.
|
||||||
|
"""
|
||||||
self._announcement_future = asyncio.Future()
|
self._announcement_future = asyncio.Future()
|
||||||
|
self._run_pipeline_after_announce = run_pipeline_after
|
||||||
|
|
||||||
if self._rtp_port is None:
|
if self._rtp_port is None:
|
||||||
# Choose random port for RTP
|
# Choose random port for RTP
|
||||||
|
@ -232,12 +246,6 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||||
"""
|
"""
|
||||||
while self._announcement is not None:
|
while self._announcement is not None:
|
||||||
current_time = time.monotonic()
|
current_time = time.monotonic()
|
||||||
_LOGGER.debug(
|
|
||||||
"%s %s %s",
|
|
||||||
self._last_chunk_time,
|
|
||||||
current_time,
|
|
||||||
self._announcment_start_time,
|
|
||||||
)
|
|
||||||
if (self._last_chunk_time is None) and (
|
if (self._last_chunk_time is None) and (
|
||||||
(current_time - self._announcment_start_time)
|
(current_time - self._announcment_start_time)
|
||||||
> _ANNOUNCEMENT_RING_TIMEOUT
|
> _ANNOUNCEMENT_RING_TIMEOUT
|
||||||
|
@ -263,6 +271,12 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||||
|
|
||||||
await asyncio.sleep(_ANNOUNCEMENT_HANGUP_SEC / 2)
|
await asyncio.sleep(_ANNOUNCEMENT_HANGUP_SEC / 2)
|
||||||
|
|
||||||
|
async def async_start_conversation(
|
||||||
|
self, start_announcement: AssistSatelliteAnnouncement
|
||||||
|
) -> None:
|
||||||
|
"""Start a conversation from the satellite."""
|
||||||
|
await self._do_announce(start_announcement, run_pipeline_after=True)
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
# VoIP
|
# VoIP
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
|
@ -347,6 +361,9 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
|
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
|
||||||
await self._send_tts(announcement.original_media_id, wait_for_tone=False)
|
await self._send_tts(announcement.original_media_id, wait_for_tone=False)
|
||||||
|
|
||||||
|
if not self._run_pipeline_after_announce:
|
||||||
|
# Delay before looping announcement
|
||||||
await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
|
await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
|
||||||
except Exception:
|
except Exception:
|
||||||
_LOGGER.exception("Unexpected error while playing announcement")
|
_LOGGER.exception("Unexpected error while playing announcement")
|
||||||
|
@ -355,6 +372,11 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||||
self._run_pipeline_task = None
|
self._run_pipeline_task = None
|
||||||
_LOGGER.debug("Announcement finished")
|
_LOGGER.debug("Announcement finished")
|
||||||
|
|
||||||
|
if self._run_pipeline_after_announce:
|
||||||
|
# Clear announcement to allow pipeline to run
|
||||||
|
self._announcement = None
|
||||||
|
self._announcement_future.set_result(None)
|
||||||
|
|
||||||
def _clear_audio_queue(self) -> None:
|
def _clear_audio_queue(self) -> None:
|
||||||
"""Ensure audio queue is empty."""
|
"""Ensure audio queue is empty."""
|
||||||
while not self._audio_queue.empty():
|
while not self._audio_queue.empty():
|
||||||
|
|
|
@ -88,6 +88,7 @@ def _base_components() -> dict[str, ModuleType]:
|
||||||
# pylint: disable-next=import-outside-toplevel
|
# pylint: disable-next=import-outside-toplevel
|
||||||
from homeassistant.components import (
|
from homeassistant.components import (
|
||||||
alarm_control_panel,
|
alarm_control_panel,
|
||||||
|
assist_satellite,
|
||||||
calendar,
|
calendar,
|
||||||
camera,
|
camera,
|
||||||
climate,
|
climate,
|
||||||
|
@ -108,6 +109,7 @@ def _base_components() -> dict[str, ModuleType]:
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"alarm_control_panel": alarm_control_panel,
|
"alarm_control_panel": alarm_control_panel,
|
||||||
|
"assist_satellite": assist_satellite,
|
||||||
"calendar": calendar,
|
"calendar": calendar,
|
||||||
"camera": camera,
|
"camera": camera,
|
||||||
"climate": climate,
|
"climate": climate,
|
||||||
|
|
|
@ -40,6 +40,8 @@ def mock_tts(mock_tts_cache_dir: pathlib.Path) -> None:
|
||||||
class MockAssistSatellite(AssistSatelliteEntity):
|
class MockAssistSatellite(AssistSatelliteEntity):
|
||||||
"""Mock Assist Satellite Entity."""
|
"""Mock Assist Satellite Entity."""
|
||||||
|
|
||||||
|
_attr_tts_options = {"test-option": "test-value"}
|
||||||
|
|
||||||
def __init__(self, name: str, features: AssistSatelliteEntityFeature) -> None:
|
def __init__(self, name: str, features: AssistSatelliteEntityFeature) -> None:
|
||||||
"""Initialize the mock entity."""
|
"""Initialize the mock entity."""
|
||||||
self._attr_unique_id = ulid_hex()
|
self._attr_unique_id = ulid_hex()
|
||||||
|
@ -67,6 +69,7 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
||||||
active_wake_words=["1234"],
|
active_wake_words=["1234"],
|
||||||
max_active_wake_words=1,
|
max_active_wake_words=1,
|
||||||
)
|
)
|
||||||
|
self.start_conversations = []
|
||||||
|
|
||||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
"""Handle pipeline events."""
|
"""Handle pipeline events."""
|
||||||
|
@ -87,11 +90,21 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
||||||
"""Set the current satellite configuration."""
|
"""Set the current satellite configuration."""
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
async def async_start_conversation(
|
||||||
|
self, start_announcement: AssistSatelliteConfiguration
|
||||||
|
) -> None:
|
||||||
|
"""Start a conversation from the satellite."""
|
||||||
|
self.start_conversations.append((self._extra_system_prompt, start_announcement))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def entity() -> MockAssistSatellite:
|
def entity() -> MockAssistSatellite:
|
||||||
"""Mock Assist Satellite Entity."""
|
"""Mock Assist Satellite Entity."""
|
||||||
return MockAssistSatellite("Test Entity", AssistSatelliteEntityFeature.ANNOUNCE)
|
return MockAssistSatellite(
|
||||||
|
"Test Entity",
|
||||||
|
AssistSatelliteEntityFeature.ANNOUNCE
|
||||||
|
| AssistSatelliteEntityFeature.START_CONVERSATION,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
@ -25,11 +25,24 @@ from homeassistant.components.assist_satellite.entity import AssistSatelliteStat
|
||||||
from homeassistant.components.media_source import PlayMedia
|
from homeassistant.components.media_source import PlayMedia
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from . import ENTITY_ID
|
from . import ENTITY_ID
|
||||||
from .conftest import MockAssistSatellite
|
from .conftest import MockAssistSatellite
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
|
||||||
|
"""Set up a pipeline with a TTS engine."""
|
||||||
|
await async_update_pipeline(
|
||||||
|
hass,
|
||||||
|
async_get_pipeline(hass),
|
||||||
|
tts_engine="tts.mock_entity",
|
||||||
|
tts_language="en",
|
||||||
|
tts_voice="test-voice",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_entity_state(
|
async def test_entity_state(
|
||||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -64,7 +77,7 @@ async def test_entity_state(
|
||||||
assert kwargs["stt_stream"] is audio_stream
|
assert kwargs["stt_stream"] is audio_stream
|
||||||
assert kwargs["pipeline_id"] is None
|
assert kwargs["pipeline_id"] is None
|
||||||
assert kwargs["device_id"] is entity.device_entry.id
|
assert kwargs["device_id"] is entity.device_entry.id
|
||||||
assert kwargs["tts_audio_output"] is None
|
assert kwargs["tts_audio_output"] == {"test-option": "test-value"}
|
||||||
assert kwargs["wake_word_phrase"] is None
|
assert kwargs["wake_word_phrase"] is None
|
||||||
assert kwargs["audio_settings"] == AudioSettings(
|
assert kwargs["audio_settings"] == AudioSettings(
|
||||||
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
||||||
|
@ -200,24 +213,12 @@ async def test_announce(
|
||||||
expected_params: tuple[str, str],
|
expected_params: tuple[str, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test announcing on a device."""
|
"""Test announcing on a device."""
|
||||||
await async_update_pipeline(
|
|
||||||
hass,
|
|
||||||
async_get_pipeline(hass),
|
|
||||||
tts_engine="tts.mock_entity",
|
|
||||||
tts_language="en",
|
|
||||||
tts_voice="test-voice",
|
|
||||||
)
|
|
||||||
|
|
||||||
entity._attr_tts_options = {"test-option": "test-value"}
|
|
||||||
|
|
||||||
original_announce = entity.async_announce
|
original_announce = entity.async_announce
|
||||||
announce_started = asyncio.Event()
|
|
||||||
|
|
||||||
async def async_announce(announcement):
|
async def async_announce(announcement):
|
||||||
# Verify state change
|
# Verify state change
|
||||||
assert entity.state == AssistSatelliteState.RESPONDING
|
assert entity.state == AssistSatelliteState.RESPONDING
|
||||||
await original_announce(announcement)
|
await original_announce(announcement)
|
||||||
announce_started.set()
|
|
||||||
|
|
||||||
def tts_generate_media_source_id(
|
def tts_generate_media_source_id(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@ -475,3 +476,104 @@ async def test_vad_sensitivity_entity_not_found(
|
||||||
|
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("service_data", "expected_params"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"start_message": "Hello",
|
||||||
|
"extra_system_prompt": "Better system prompt",
|
||||||
|
},
|
||||||
|
(
|
||||||
|
"Better system prompt",
|
||||||
|
AssistSatelliteAnnouncement(
|
||||||
|
message="Hello",
|
||||||
|
media_id="https://www.home-assistant.io/resolved.mp3",
|
||||||
|
original_media_id="media-source://generated",
|
||||||
|
media_id_source="tts",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
"start_message": "Hello",
|
||||||
|
"start_media_id": "media-source://given",
|
||||||
|
},
|
||||||
|
(
|
||||||
|
"Hello",
|
||||||
|
AssistSatelliteAnnouncement(
|
||||||
|
message="Hello",
|
||||||
|
media_id="https://www.home-assistant.io/resolved.mp3",
|
||||||
|
original_media_id="media-source://given",
|
||||||
|
media_id_source="media_id",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"start_media_id": "http://example.com/given.mp3"},
|
||||||
|
(
|
||||||
|
None,
|
||||||
|
AssistSatelliteAnnouncement(
|
||||||
|
message="",
|
||||||
|
media_id="http://example.com/given.mp3",
|
||||||
|
original_media_id="http://example.com/given.mp3",
|
||||||
|
media_id_source="url",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_start_conversation(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
service_data: dict,
|
||||||
|
expected_params: tuple[str, str],
|
||||||
|
) -> None:
|
||||||
|
"""Test starting a conversation on a device."""
|
||||||
|
await async_update_pipeline(
|
||||||
|
hass,
|
||||||
|
async_get_pipeline(hass),
|
||||||
|
conversation_engine="conversation.some_llm",
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||||
|
return_value="media-source://generated",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
|
return_value=PlayMedia(
|
||||||
|
url="https://www.home-assistant.io/resolved.mp3",
|
||||||
|
mime_type="audio/mp3",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await hass.services.async_call(
|
||||||
|
"assist_satellite",
|
||||||
|
"start_conversation",
|
||||||
|
service_data,
|
||||||
|
target={"entity_id": "assist_satellite.test_entity"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert entity.start_conversations[0] == expected_params
|
||||||
|
|
||||||
|
|
||||||
|
async def test_start_conversation_reject_builtin_agent(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: ConfigEntry,
|
||||||
|
entity: MockAssistSatellite,
|
||||||
|
) -> None:
|
||||||
|
"""Test starting a conversation on a device."""
|
||||||
|
with pytest.raises(HomeAssistantError):
|
||||||
|
await hass.services.async_call(
|
||||||
|
"assist_satellite",
|
||||||
|
"start_conversation",
|
||||||
|
{"start_message": "Hey!"},
|
||||||
|
target={"entity_id": "assist_satellite.test_entity"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
|
@ -887,6 +887,7 @@ async def test_announce(
|
||||||
|
|
||||||
# Trigger announcement
|
# Trigger announcement
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
async with asyncio.timeout(1):
|
||||||
await announce_task
|
await announce_task
|
||||||
|
|
||||||
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
|
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
|
||||||
|
@ -938,6 +939,7 @@ async def test_voip_id_is_ip_address(
|
||||||
|
|
||||||
# Trigger announcement
|
# Trigger announcement
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
async with asyncio.timeout(1):
|
||||||
await announce_task
|
await announce_task
|
||||||
|
|
||||||
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
|
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
|
||||||
|
@ -981,3 +983,104 @@ async def test_announce_timeout(
|
||||||
satellite.transport = Mock()
|
satellite.transport = Mock()
|
||||||
with pytest.raises(TimeoutError):
|
with pytest.raises(TimeoutError):
|
||||||
await satellite.async_announce(announcement)
|
await satellite.async_announce(announcement)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("socket_enabled")
|
||||||
|
async def test_start_conversation(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voip_devices: VoIPDevices,
|
||||||
|
voip_device: VoIPDevice,
|
||||||
|
) -> None:
|
||||||
|
"""Test start conversation."""
|
||||||
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
|
assert (
|
||||||
|
satellite.supported_features
|
||||||
|
& assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
|
||||||
|
)
|
||||||
|
|
||||||
|
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||||
|
message="test announcement",
|
||||||
|
media_id=_MEDIA_ID,
|
||||||
|
original_media_id=_MEDIA_ID,
|
||||||
|
media_id_source="tts",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Protocol has already been mocked, but "outgoing_call" is not async
|
||||||
|
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
|
||||||
|
mock_protocol.outgoing_call = Mock()
|
||||||
|
|
||||||
|
tts_sent = asyncio.Event()
|
||||||
|
|
||||||
|
async def _send_tts(*args, **kwargs):
|
||||||
|
tts_sent.set()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
context: Context,
|
||||||
|
*args,
|
||||||
|
device_id: str | None,
|
||||||
|
tts_audio_output: str | dict[str, Any] | None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
event_callback = kwargs["event_callback"]
|
||||||
|
|
||||||
|
# Fake tts result
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.TTS_START,
|
||||||
|
data={
|
||||||
|
"engine": "test",
|
||||||
|
"language": hass.config.language,
|
||||||
|
"voice": "test",
|
||||||
|
"tts_input": "fake-text",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Proceed with media output
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||||
|
data={"tts_output": {"media_id": _MEDIA_ID}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
type=assist_pipeline.PipelineEventType.RUN_END
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
||||||
|
new=_send_tts,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
satellite.transport = Mock()
|
||||||
|
conversation_task = hass.async_create_background_task(
|
||||||
|
satellite.async_start_conversation(announcement), "voip_start_conversation"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
mock_protocol.outgoing_call.assert_called_once()
|
||||||
|
|
||||||
|
# Trigger announcement and wait for it to finish
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await tts_sent.wait()
|
||||||
|
|
||||||
|
tts_sent.clear()
|
||||||
|
|
||||||
|
# Trigger pipeline
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
# Wait for TTS
|
||||||
|
await tts_sent.wait()
|
||||||
|
await conversation_task
|
||||||
|
|
Loading…
Reference in New Issue