Send language to Wyoming STT (#97344)
parent
5aa3e36754
commit
8ad37d7640
|
@ -2,7 +2,7 @@
|
|||
from collections.abc import AsyncIterable
|
||||
import logging
|
||||
|
||||
from wyoming.asr import Transcript
|
||||
from wyoming.asr import Transcribe, Transcript
|
||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.client import AsyncTcpClient
|
||||
|
||||
|
@ -89,6 +89,10 @@ class WyomingSttProvider(stt.SpeechToTextEntity):
|
|||
"""Process an audio stream to STT service."""
|
||||
try:
|
||||
async with AsyncTcpClient(self.service.host, self.service.port) as client:
|
||||
# Set transcription language
|
||||
await client.write_event(Transcribe(language=metadata.language).event())
|
||||
|
||||
# Begin audio stream
|
||||
await client.write_event(
|
||||
AudioStart(
|
||||
rate=SAMPLE_RATE,
|
||||
|
@ -106,6 +110,7 @@ class WyomingSttProvider(stt.SpeechToTextEntity):
|
|||
)
|
||||
await client.write_event(chunk.event())
|
||||
|
||||
# End audio stream
|
||||
await client.write_event(AudioStop().event())
|
||||
|
||||
while True:
|
||||
|
|
|
@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
@ -69,3 +70,16 @@ async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry):
|
|||
return_value=TTS_INFO,
|
||||
):
|
||||
await hass.config_entries.async_setup(tts_config_entry.entry_id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
|
||||
"""Get default STT metadata."""
|
||||
return stt.SpeechMetadata(
|
||||
language=hass.config.language,
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
)
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
# serializer version: 1
|
||||
# name: test_streaming_audio
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'language': 'en',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'transcibe',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'channels': 1,
|
||||
|
|
|
@ -27,7 +27,9 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
|||
assert entity.supported_channels == [stt.AudioChannels.CHANNEL_MONO]
|
||||
|
||||
|
||||
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
|
||||
async def test_streaming_audio(
|
||||
hass: HomeAssistant, init_wyoming_stt, metadata, snapshot
|
||||
) -> None:
|
||||
"""Test streaming audio."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
|
||||
assert entity is not None
|
||||
|
@ -40,7 +42,7 @@ async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot)
|
|||
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
||||
MockAsyncTcpClient([Transcript(text="Hello world").event()]),
|
||||
) as mock_client:
|
||||
result = await entity.async_process_audio_stream(None, audio_stream())
|
||||
result = await entity.async_process_audio_stream(metadata, audio_stream())
|
||||
|
||||
assert result.result == stt.SpeechResultState.SUCCESS
|
||||
assert result.text == "Hello world"
|
||||
|
@ -48,7 +50,7 @@ async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot)
|
|||
|
||||
|
||||
async def test_streaming_audio_connection_lost(
|
||||
hass: HomeAssistant, init_wyoming_stt
|
||||
hass: HomeAssistant, init_wyoming_stt, metadata
|
||||
) -> None:
|
||||
"""Test streaming audio and losing connection."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
|
||||
|
@ -61,13 +63,15 @@ async def test_streaming_audio_connection_lost(
|
|||
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
||||
MockAsyncTcpClient([None]),
|
||||
):
|
||||
result = await entity.async_process_audio_stream(None, audio_stream())
|
||||
result = await entity.async_process_audio_stream(metadata, audio_stream())
|
||||
|
||||
assert result.result == stt.SpeechResultState.ERROR
|
||||
assert result.text is None
|
||||
|
||||
|
||||
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||
async def test_streaming_audio_oserror(
|
||||
hass: HomeAssistant, init_wyoming_stt, metadata
|
||||
) -> None:
|
||||
"""Test streaming audio and error raising."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
|
||||
assert entity is not None
|
||||
|
@ -81,7 +85,7 @@ async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) ->
|
|||
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
||||
mock_client,
|
||||
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
|
||||
result = await entity.async_process_audio_stream(None, audio_stream())
|
||||
result = await entity.async_process_audio_stream(metadata, audio_stream())
|
||||
|
||||
assert result.result == stt.SpeechResultState.ERROR
|
||||
assert result.text is None
|
||||
|
|
Loading…
Reference in New Issue