Change assist satellite announce method signature (#126299)
parent
41ffa8d6db
commit
604c848dec
|
@ -12,6 +12,7 @@ from homeassistant.helpers.typing import ConfigType
|
|||
|
||||
from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature
|
||||
from .entity import (
|
||||
AssistSatelliteAnnouncement,
|
||||
AssistSatelliteConfiguration,
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityDescription,
|
||||
|
@ -22,6 +23,7 @@ from .websocket_api import async_register_websocket_api
|
|||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"AssistSatelliteAnnouncement",
|
||||
"AssistSatelliteEntity",
|
||||
"AssistSatelliteConfiguration",
|
||||
"AssistSatelliteEntityDescription",
|
||||
|
|
|
@ -8,7 +8,7 @@ from dataclasses import dataclass
|
|||
from enum import StrEnum
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Final, final
|
||||
from typing import Any, Final, Literal, final
|
||||
|
||||
from homeassistant.components import media_source, stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
|
@ -86,6 +86,19 @@ class AssistSatelliteConfiguration:
|
|||
"""Maximum number of simultaneous wake words allowed (0 for no limit)."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssistSatelliteAnnouncement:
|
||||
"""Announcement to be made."""
|
||||
|
||||
message: str
|
||||
"""Message to be spoken."""
|
||||
|
||||
media_id: str
|
||||
"""Media ID to be played."""
|
||||
|
||||
media_id_source: Literal["url", "media_id", "tts"]
|
||||
|
||||
|
||||
class AssistSatelliteEntity(entity.Entity):
|
||||
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||
|
||||
|
@ -174,10 +187,13 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
"""
|
||||
await self._cancel_running_pipeline()
|
||||
|
||||
media_id_source: Literal["url", "media_id", "tts"] | None = None
|
||||
|
||||
if message is None:
|
||||
message = ""
|
||||
|
||||
if not media_id:
|
||||
media_id_source = "tts"
|
||||
# Synthesize audio and get URL
|
||||
pipeline_id = self._resolve_pipeline()
|
||||
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||
|
@ -198,6 +214,8 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
)
|
||||
|
||||
if media_source.is_media_source_id(media_id):
|
||||
if not media_id_source:
|
||||
media_id_source = "media_id"
|
||||
media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
media_id,
|
||||
|
@ -205,6 +223,9 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
)
|
||||
media_id = media.url
|
||||
|
||||
if not media_id_source:
|
||||
media_id_source = "url"
|
||||
|
||||
# Resolve to full URL
|
||||
media_id = async_process_play_media_url(self.hass, media_id)
|
||||
|
||||
|
@ -216,12 +237,14 @@ class AssistSatelliteEntity(entity.Entity):
|
|||
|
||||
try:
|
||||
# Block until announcement is finished
|
||||
await self.async_announce(message, media_id)
|
||||
await self.async_announce(
|
||||
AssistSatelliteAnnouncement(message, media_id, media_id_source)
|
||||
)
|
||||
finally:
|
||||
self._is_announcing = False
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
async def async_announce(self, message: str, media_id: str) -> None:
|
||||
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
|
||||
"""Announce media on the satellite.
|
||||
|
||||
Should block until the announcement is done playing.
|
||||
|
|
|
@ -313,18 +313,20 @@ class EsphomeAssistSatellite(
|
|||
|
||||
self.cli.send_voice_assistant_event(event_type, data_to_send)
|
||||
|
||||
async def async_announce(self, message: str, media_id: str) -> None:
|
||||
async def async_announce(
|
||||
self, announcement: assist_satellite.AssistSatelliteAnnouncement
|
||||
) -> None:
|
||||
"""Announce media on the satellite.
|
||||
|
||||
Should block until the announcement is done playing.
|
||||
"""
|
||||
_LOGGER.debug(
|
||||
"Waiting for announcement to finished (message=%s, media_id=%s)",
|
||||
message,
|
||||
media_id,
|
||||
announcement.message,
|
||||
announcement.media_id,
|
||||
)
|
||||
await self.cli.send_voice_assistant_announcement_await_response(
|
||||
media_id, _ANNOUNCEMENT_TIMEOUT_SEC, message
|
||||
announcement.media_id, _ANNOUNCEMENT_TIMEOUT_SEC, announcement.message
|
||||
)
|
||||
|
||||
async def handle_pipeline_start(
|
||||
|
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||
from homeassistant.components.assist_satellite import (
|
||||
DOMAIN as AS_DOMAIN,
|
||||
AssistSatelliteAnnouncement,
|
||||
AssistSatelliteConfiguration,
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityFeature,
|
||||
|
@ -63,9 +64,9 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
|||
"""Handle pipeline events."""
|
||||
self.events.append(event)
|
||||
|
||||
async def async_announce(self, message: str, media_id: str) -> None:
|
||||
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
|
||||
"""Announce media on a device."""
|
||||
self.announcements.append((message, media_id))
|
||||
self.announcements.append(announcement)
|
||||
|
||||
@callback
|
||||
def async_get_configuration(self) -> AssistSatelliteConfiguration:
|
||||
|
|
|
@ -17,7 +17,10 @@ from homeassistant.components.assist_pipeline import (
|
|||
async_update_pipeline,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import SatelliteBusyError
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteAnnouncement,
|
||||
SatelliteBusyError,
|
||||
)
|
||||
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
|
||||
from homeassistant.components.media_source import PlayMedia
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
@ -159,18 +162,22 @@ async def test_new_pipeline_cancels_pipeline(
|
|||
[
|
||||
(
|
||||
{"message": "Hello"},
|
||||
("Hello", "https://www.home-assistant.io/resolved.mp3"),
|
||||
AssistSatelliteAnnouncement(
|
||||
"Hello", "https://www.home-assistant.io/resolved.mp3", "tts"
|
||||
),
|
||||
),
|
||||
(
|
||||
{
|
||||
"message": "Hello",
|
||||
"media_id": "http://example.com/bla.mp3",
|
||||
"media_id": "media-source://bla",
|
||||
},
|
||||
("Hello", "http://example.com/bla.mp3"),
|
||||
AssistSatelliteAnnouncement(
|
||||
"Hello", "https://www.home-assistant.io/resolved.mp3", "media_id"
|
||||
),
|
||||
),
|
||||
(
|
||||
{"media_id": "http://example.com/bla.mp3"},
|
||||
("", "http://example.com/bla.mp3"),
|
||||
AssistSatelliteAnnouncement("", "http://example.com/bla.mp3", "url"),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -195,10 +202,10 @@ async def test_announce(
|
|||
original_announce = entity.async_announce
|
||||
announce_started = asyncio.Event()
|
||||
|
||||
async def async_announce(message, media_id):
|
||||
async def async_announce(announcement):
|
||||
# Verify state change
|
||||
assert entity.state == AssistSatelliteState.RESPONDING
|
||||
await original_announce(message, media_id)
|
||||
await original_announce(announcement)
|
||||
announce_started.set()
|
||||
|
||||
def tts_generate_media_source_id(
|
||||
|
@ -249,7 +256,7 @@ async def test_announce_busy(
|
|||
announce_started = asyncio.Event()
|
||||
got_error = asyncio.Event()
|
||||
|
||||
async def async_announce(message, media_id):
|
||||
async def async_announce(announcement):
|
||||
announce_started.set()
|
||||
|
||||
# Block so we can do another announcement
|
||||
|
|
Loading…
Reference in New Issue