156 lines
5.1 KiB
Python
156 lines
5.1 KiB
Python
"""Support for Wyoming text to speech services."""
|
|
from collections import defaultdict
|
|
import io
|
|
import logging
|
|
import wave
|
|
|
|
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
|
|
from wyoming.client import AsyncTcpClient
|
|
from wyoming.tts import Synthesize
|
|
|
|
from homeassistant.components import tts
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.core import HomeAssistant, callback
|
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|
|
|
from .const import DOMAIN
|
|
from .data import WyomingService
|
|
from .error import WyomingError
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Set up the Wyoming text to speech entity for a config entry.

    Looks up the WyomingService stored in ``hass.data`` under this
    integration's domain and the entry id, wraps it in a
    ``WyomingTtsProvider`` and registers that single entity.
    """
    wyoming_service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
    provider = WyomingTtsProvider(config_entry, wyoming_service)
    async_add_entities([provider])
|
|
|
|
|
|
class WyomingTtsProvider(tts.TextToSpeechEntity):
    """Wyoming text to speech provider.

    Wraps the first installed TTS program reported by a Wyoming service and
    exposes its installed voices grouped by language.
    """

    def __init__(
        self,
        config_entry: ConfigEntry,
        service: WyomingService,
    ) -> None:
        """Set up provider.

        Picks the first installed TTS program from ``service.info.tts`` and
        indexes its installed voices by language. ``next`` is called without
        a default, so ``StopIteration`` is raised if no program is installed.
        """
        self.service = service
        # The loop variable is deliberately not called ``tts`` so it does not
        # shadow the imported ``tts`` module.
        self._tts_service = next(
            program for program in service.info.tts if program.installed
        )

        # language -> voices available for that language.
        self._voices: dict[str, list[tts.Voice]] = defaultdict(list)
        for voice in self._tts_service.voices:
            if not voice.installed:
                continue

            for language in voice.languages:
                self._voices[language].append(
                    tts.Voice(
                        voice_id=voice.name,
                        name=voice.name,
                    )
                )

        # Dict keys keep first-seen order, so the default language (index 0)
        # is stable across restarts; iterating a set of strings is not.
        self._supported_languages: list[str] = list(self._voices)

        self._attr_name = self._tts_service.name
        self._attr_unique_id = f"{config_entry.entry_id}-tts"

    @property
    def default_language(self):
        """Return default language, or None if no voice is installed."""
        if not self._supported_languages:
            return None

        return self._supported_languages[0]

    @property
    def supported_languages(self):
        """Return list of supported languages."""
        return self._supported_languages

    @property
    def supported_options(self):
        """Return list of supported options like voice, emotion."""
        return [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]

    @property
    def default_options(self):
        """Return a dict including default options."""
        return {tts.ATTR_AUDIO_OUTPUT: "wav"}

    @callback
    def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
        """Return a list of supported voices for a language.

        Returns None when no installed voice covers ``language``.
        """
        # ``.get`` (not indexing) so the defaultdict does not grow on lookups.
        return self._voices.get(language)

    async def async_get_tts_audio(self, message, language, options=None):
        """Load TTS audio from the Wyoming service over TCP.

        Sends a ``Synthesize`` event for ``message`` and collects the audio
        chunks sent back into an in-memory WAV file until an ``AudioStop``
        event arrives.

        Returns ``("wav", data)`` when the requested audio output is "wav"
        (also the fallback when ``options`` is empty or lacks the key), or
        ``("raw", data)`` with the audio converted to 16 kHz, 16-bit mono
        otherwise. Returns ``(None, None)`` when the connection fails or is
        lost, or when the service sends no audio at all.
        """
        # NOTE(review): ATTR_VOICE is advertised in supported_options but the
        # selected voice is not forwarded in the Synthesize request; confirm
        # whether the installed wyoming library supports passing a voice.
        try:
            async with AsyncTcpClient(self.service.host, self.service.port) as client:
                await client.write_event(Synthesize(message).event())

                with io.BytesIO() as wav_io:
                    wav_writer: wave.Wave_write | None = None
                    try:
                        while True:
                            event = await client.read_event()
                            if event is None:
                                _LOGGER.debug("Connection lost")
                                return (None, None)

                            if AudioStop.is_type(event.type):
                                break

                            if AudioChunk.is_type(event.type):
                                chunk = AudioChunk.from_event(event)
                                if wav_writer is None:
                                    # The first chunk defines the WAV format.
                                    wav_writer = wave.open(wav_io, "wb")
                                    wav_writer.setframerate(chunk.rate)
                                    wav_writer.setsampwidth(chunk.width)
                                    wav_writer.setnchannels(chunk.channels)

                                wav_writer.writeframes(chunk.audio)
                    finally:
                        # Close while wav_io is still open, on every exit
                        # path; closing patches the WAV header sizes.
                        if wav_writer is not None:
                            wav_writer.close()

                    if wav_writer is None:
                        # No audio chunk was received, so there is no valid
                        # WAV data to return.
                        _LOGGER.debug("No audio received")
                        return (None, None)

                    data = wav_io.getvalue()

        except (OSError, WyomingError):
            return (None, None)

        # Treat a missing audio output option like the declared default.
        if not options or options.get(tts.ATTR_AUDIO_OUTPUT, "wav") == "wav":
            return ("wav", data)

        # Raw output (convert to 16Khz, 16-bit mono)
        with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
            source_chunk = AudioChunk(
                audio=wav_reader.readframes(wav_reader.getnframes()),
                rate=wav_reader.getframerate(),
                width=wav_reader.getsampwidth(),
                channels=wav_reader.getnchannels(),
            )

        converter = AudioChunkConverter(
            rate=16000,
            width=2,
            channels=1,
        )
        raw_data = converter.convert(source_chunk).audio

        return ("raw", raw_data)
|