Send language to Wyoming STT (#97344)

pull/97536/head^2
Michael Hansen 2023-08-01 03:05:01 -05:00 committed by GitHub
parent 5aa3e36754
commit 8ad37d7640
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 37 additions and 7 deletions

View File

@ -2,7 +2,7 @@
from collections.abc import AsyncIterable
import logging
from wyoming.asr import Transcript
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient
@ -89,6 +89,10 @@ class WyomingSttProvider(stt.SpeechToTextEntity):
"""Process an audio stream to STT service."""
try:
async with AsyncTcpClient(self.service.host, self.service.port) as client:
# Set transcription language
await client.write_event(Transcribe(language=metadata.language).event())
# Begin audio stream
await client.write_event(
AudioStart(
rate=SAMPLE_RATE,
@ -106,6 +110,7 @@ class WyomingSttProvider(stt.SpeechToTextEntity):
)
await client.write_event(chunk.event())
# End audio stream
await client.write_event(AudioStop().event())
while True:

View File

@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.components import stt
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
@ -69,3 +70,16 @@ async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry):
return_value=TTS_INFO,
):
await hass.config_entries.async_setup(tts_config_entry.entry_id)
@pytest.fixture
def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
"""Get default STT metadata."""
return stt.SpeechMetadata(
language=hass.config.language,
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
)

View File

@ -1,6 +1,13 @@
# serializer version: 1
# name: test_streaming_audio
list([
dict({
'data': dict({
'language': 'en',
}),
'payload': None,
'type': 'transcibe',
}),
dict({
'data': dict({
'channels': 1,

View File

@ -27,7 +27,9 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
assert entity.supported_channels == [stt.AudioChannels.CHANNEL_MONO]
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
async def test_streaming_audio(
hass: HomeAssistant, init_wyoming_stt, metadata, snapshot
) -> None:
"""Test streaming audio."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None
@ -40,7 +42,7 @@ async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot)
"homeassistant.components.wyoming.stt.AsyncTcpClient",
MockAsyncTcpClient([Transcript(text="Hello world").event()]),
) as mock_client:
result = await entity.async_process_audio_stream(None, audio_stream())
result = await entity.async_process_audio_stream(metadata, audio_stream())
assert result.result == stt.SpeechResultState.SUCCESS
assert result.text == "Hello world"
@ -48,7 +50,7 @@ async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot)
async def test_streaming_audio_connection_lost(
hass: HomeAssistant, init_wyoming_stt
hass: HomeAssistant, init_wyoming_stt, metadata
) -> None:
"""Test streaming audio and losing connection."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
@ -61,13 +63,15 @@ async def test_streaming_audio_connection_lost(
"homeassistant.components.wyoming.stt.AsyncTcpClient",
MockAsyncTcpClient([None]),
):
result = await entity.async_process_audio_stream(None, audio_stream())
result = await entity.async_process_audio_stream(metadata, audio_stream())
assert result.result == stt.SpeechResultState.ERROR
assert result.text is None
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
async def test_streaming_audio_oserror(
hass: HomeAssistant, init_wyoming_stt, metadata
) -> None:
"""Test streaming audio and error raising."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None
@ -81,7 +85,7 @@ async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) ->
"homeassistant.components.wyoming.stt.AsyncTcpClient",
mock_client,
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
result = await entity.async_process_audio_stream(None, audio_stream())
result = await entity.async_process_audio_stream(metadata, audio_stream())
assert result.result == stt.SpeechResultState.ERROR
assert result.text is None