Change assist satellite announce method signature (#126299)

pull/126343/head
Paulus Schoutsen 2024-09-20 09:09:37 -04:00 committed by GitHub
parent 41ffa8d6db
commit 604c848dec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 52 additions and 17 deletions

View File

@ -12,6 +12,7 @@ from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature
from .entity import (
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
@ -22,6 +23,7 @@ from .websocket_api import async_register_websocket_api
__all__ = [
"DOMAIN",
"AssistSatelliteAnnouncement",
"AssistSatelliteEntity",
"AssistSatelliteConfiguration",
"AssistSatelliteEntityDescription",

View File

@ -8,7 +8,7 @@ from dataclasses import dataclass
from enum import StrEnum
import logging
import time
from typing import Any, Final, final
from typing import Any, Final, Literal, final
from homeassistant.components import media_source, stt, tts
from homeassistant.components.assist_pipeline import (
@ -86,6 +86,19 @@ class AssistSatelliteConfiguration:
"""Maximum number of simultaneous wake words allowed (0 for no limit)."""
@dataclass
class AssistSatelliteAnnouncement:
"""Announcement to be made."""
message: str
"""Message to be spoken."""
media_id: str
"""Media ID to be played."""
media_id_source: Literal["url", "media_id", "tts"]
class AssistSatelliteEntity(entity.Entity):
"""Entity encapsulating the state and functionality of an Assist satellite."""
@ -174,10 +187,13 @@ class AssistSatelliteEntity(entity.Entity):
"""
await self._cancel_running_pipeline()
media_id_source: Literal["url", "media_id", "tts"] | None = None
if message is None:
message = ""
if not media_id:
media_id_source = "tts"
# Synthesize audio and get URL
pipeline_id = self._resolve_pipeline()
pipeline = async_get_pipeline(self.hass, pipeline_id)
@ -198,6 +214,8 @@ class AssistSatelliteEntity(entity.Entity):
)
if media_source.is_media_source_id(media_id):
if not media_id_source:
media_id_source = "media_id"
media = await media_source.async_resolve_media(
self.hass,
media_id,
@ -205,6 +223,9 @@ class AssistSatelliteEntity(entity.Entity):
)
media_id = media.url
if not media_id_source:
media_id_source = "url"
# Resolve to full URL
media_id = async_process_play_media_url(self.hass, media_id)
@ -216,12 +237,14 @@ class AssistSatelliteEntity(entity.Entity):
try:
# Block until announcement is finished
await self.async_announce(message, media_id)
await self.async_announce(
AssistSatelliteAnnouncement(message, media_id, media_id_source)
)
finally:
self._is_announcing = False
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
async def async_announce(self, message: str, media_id: str) -> None:
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
"""Announce media on the satellite.
Should block until the announcement is done playing.

View File

@ -313,18 +313,20 @@ class EsphomeAssistSatellite(
self.cli.send_voice_assistant_event(event_type, data_to_send)
async def async_announce(self, message: str, media_id: str) -> None:
async def async_announce(
self, announcement: assist_satellite.AssistSatelliteAnnouncement
) -> None:
"""Announce media on the satellite.
Should block until the announcement is done playing.
"""
_LOGGER.debug(
"Waiting for announcement to finished (message=%s, media_id=%s)",
message,
media_id,
announcement.message,
announcement.media_id,
)
await self.cli.send_voice_assistant_announcement_await_response(
media_id, _ANNOUNCEMENT_TIMEOUT_SEC, message
announcement.media_id, _ANNOUNCEMENT_TIMEOUT_SEC, announcement.message
)
async def handle_pipeline_start(

View File

@ -8,6 +8,7 @@ import pytest
from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN,
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityFeature,
@ -63,9 +64,9 @@ class MockAssistSatellite(AssistSatelliteEntity):
"""Handle pipeline events."""
self.events.append(event)
async def async_announce(self, message: str, media_id: str) -> None:
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
"""Announce media on a device."""
self.announcements.append((message, media_id))
self.announcements.append(announcement)
@callback
def async_get_configuration(self) -> AssistSatelliteConfiguration:

View File

@ -17,7 +17,10 @@ from homeassistant.components.assist_pipeline import (
async_update_pipeline,
vad,
)
from homeassistant.components.assist_satellite import SatelliteBusyError
from homeassistant.components.assist_satellite import (
AssistSatelliteAnnouncement,
SatelliteBusyError,
)
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
from homeassistant.components.media_source import PlayMedia
from homeassistant.config_entries import ConfigEntry
@ -159,18 +162,22 @@ async def test_new_pipeline_cancels_pipeline(
[
(
{"message": "Hello"},
("Hello", "https://www.home-assistant.io/resolved.mp3"),
AssistSatelliteAnnouncement(
"Hello", "https://www.home-assistant.io/resolved.mp3", "tts"
),
),
(
{
"message": "Hello",
"media_id": "http://example.com/bla.mp3",
"media_id": "media-source://bla",
},
("Hello", "http://example.com/bla.mp3"),
AssistSatelliteAnnouncement(
"Hello", "https://www.home-assistant.io/resolved.mp3", "media_id"
),
),
(
{"media_id": "http://example.com/bla.mp3"},
("", "http://example.com/bla.mp3"),
AssistSatelliteAnnouncement("", "http://example.com/bla.mp3", "url"),
),
],
)
@ -195,10 +202,10 @@ async def test_announce(
original_announce = entity.async_announce
announce_started = asyncio.Event()
async def async_announce(message, media_id):
async def async_announce(announcement):
# Verify state change
assert entity.state == AssistSatelliteState.RESPONDING
await original_announce(message, media_id)
await original_announce(announcement)
announce_started.set()
def tts_generate_media_source_id(
@ -249,7 +256,7 @@ async def test_announce_busy(
announce_started = asyncio.Event()
got_error = asyncio.Event()
async def async_announce(message, media_id):
async def async_announce(announcement):
announce_started.set()
# Block so we can do another announcement