Drop language parameter from async_get_pipeline (#91612)

pull/91600/head
Erik Montnemery 2023-04-18 18:07:20 +02:00 committed by GitHub
parent 10606c4d1e
commit bdffb1f298
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 34 additions and 58 deletions

View File

@@ -45,28 +45,16 @@ async def async_pipeline_from_audio_stream(
event_callback: PipelineEventCallback,
stt_metadata: stt.SpeechMetadata,
stt_stream: AsyncIterable[bytes],
language: str | None = None,
pipeline_id: str | None = None,
conversation_id: str | None = None,
context: Context | None = None,
tts_options: dict | None = None,
) -> None:
"""Create an audio pipeline from an audio stream."""
if language is None and pipeline_id is None:
language = hass.config.language
# Temporary workaround for language codes
if language == "en":
language = "en-US"
if context is None:
context = Context()
pipeline = await async_get_pipeline(
hass,
pipeline_id=pipeline_id,
language=language,
)
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None:
raise PipelineNotFound(
"pipeline_not_found", f"Pipeline {pipeline_id} not found"

View File

@@ -53,7 +53,7 @@ SAVE_DELAY = 10
async def async_get_pipeline(
hass: HomeAssistant, pipeline_id: str | None = None, language: str | None = None
hass: HomeAssistant, pipeline_id: str | None = None
) -> Pipeline | None:
"""Get a pipeline by id or create one for a language."""
pipeline_data: PipelineData = hass.data[DOMAIN]
@@ -64,12 +64,11 @@ async def async_get_pipeline(
if pipeline_id is None:
# There's no preferred pipeline, construct a pipeline for the
# required/configured language
language = language or hass.config.language
# configured language
return await pipeline_data.pipeline_store.async_create_item(
{
"name": language,
"language": language,
"name": hass.config.language,
"language": hass.config.language,
"stt_engine": None, # first engine
"conversation_engine": None, # first agent
"tts_engine": None, # first engine

View File

@@ -46,7 +46,6 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
# pylint: disable-next=unnecessary-lambda
vol.Required("end_stage"): lambda val: PipelineStage(val),
vol.Optional("input"): dict,
vol.Optional("language"): str,
vol.Optional("pipeline"): str,
vol.Optional("conversation_id"): vol.Any(str, None),
vol.Optional("timeout"): vol.Any(float, int),
@@ -82,23 +81,13 @@ async def websocket_run(
msg: dict[str, Any],
) -> None:
"""Run a pipeline."""
language = msg.get("language", hass.config.language)
# Temporary workaround for language codes
if language == "en":
language = "en-US"
pipeline_id = msg.get("pipeline")
pipeline = await async_get_pipeline(
hass,
pipeline_id=pipeline_id,
language=language,
)
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None:
connection.send_error(
msg["id"],
"pipeline-not-found",
f"Pipeline not found: id={pipeline_id}, language={language}",
f"Pipeline not found: id={pipeline_id}",
)
return
@@ -147,7 +136,7 @@ async def websocket_run(
# Audio input must be raw PCM at 16Khz with 16-bit mono samples
input_args["stt_metadata"] = stt.SpeechMetadata(
language=language,
language=pipeline.language,
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,

View File

@@ -3,8 +3,8 @@
list([
dict({
'data': dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
@@ -47,7 +47,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
@@ -70,7 +70,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),

View File

@@ -1,8 +1,8 @@
# serializer version: 1
# name: test_audio_pipeline
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
@@ -45,7 +45,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
@@ -66,7 +66,7 @@
# name: test_audio_pipeline.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
@@ -74,8 +74,8 @@
# ---
# name: test_audio_pipeline_debug
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
@@ -118,7 +118,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
@@ -139,7 +139,7 @@
# name: test_audio_pipeline_debug.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
@@ -147,8 +147,8 @@
# ---
# name: test_intent_failed
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 30,
@@ -163,8 +163,8 @@
# ---
# name: test_intent_timeout
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 0.1,
@@ -185,8 +185,8 @@
# ---
# name: test_stt_provider_missing
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
@@ -201,15 +201,15 @@
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'language': 'en',
'sample_rate': 16000,
}),
})
# ---
# name: test_stt_stream_failed
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
@@ -231,8 +231,8 @@
# ---
# name: test_text_only_pipeline
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 30,
@@ -255,7 +255,7 @@
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
@@ -275,8 +275,8 @@
# ---
# name: test_tts_failed
dict({
'language': 'en-US',
'pipeline': 'en-US',
'language': 'en',
'pipeline': 'en',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 30,