2023-03-17 00:42:26 +00:00
|
|
|
"""Websocket tests for Voice Assistant integration."""
|
|
|
|
import asyncio
|
2023-04-06 16:55:16 +00:00
|
|
|
from unittest.mock import ANY, MagicMock, patch
|
2023-03-17 00:42:26 +00:00
|
|
|
|
2023-03-27 02:41:17 +00:00
|
|
|
from syrupy.assertion import SnapshotAssertion
|
2023-03-17 00:42:26 +00:00
|
|
|
|
2023-04-13 21:25:38 +00:00
|
|
|
from homeassistant.components.assist_pipeline.const import DOMAIN
|
|
|
|
from homeassistant.components.assist_pipeline.pipeline import (
|
2023-04-06 16:55:16 +00:00
|
|
|
Pipeline,
|
|
|
|
PipelineStorageCollection,
|
|
|
|
)
|
2023-03-17 00:42:26 +00:00
|
|
|
from homeassistant.core import HomeAssistant
|
2023-03-28 13:22:48 +00:00
|
|
|
|
2023-04-04 04:06:51 +00:00
|
|
|
from tests.typing import WebSocketGenerator
|
2023-03-17 00:42:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
async def test_text_only_pipeline(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test events from a pipeline run with text input (no STT/TTS)."""
    client = await hass_ws_client(hass)

    run_request = {
        "type": "assist_pipeline/run",
        "start_stage": "intent",
        "end_stage": "intent",
        "input": {"text": "Are the lights on?"},
    }
    await client.send_json_auto_id(run_request)

    # The command reply comes first, before any pipeline events.
    reply = await client.receive_json()
    assert reply["success"]

    # run-start and both intent events carry snapshot-checked payloads,
    # compared in this exact order.
    for expected_type in ("run-start", "intent-start", "intent-end"):
        reply = await client.receive_json()
        assert reply["event"]["type"] == expected_type
        assert reply["event"]["data"] == snapshot

    # run-end carries no payload.
    reply = await client.receive_json()
    assert reply["event"]["type"] == "run-end"
    assert reply["event"]["data"] is None
|
2023-03-23 18:44:19 +00:00
|
|
|
|
|
|
|
|
|
|
|
async def test_audio_pipeline(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test events from a pipeline run with audio input/output."""
    client = await hass_ws_client(hass)

    async def expect_event(event_type: str) -> None:
        """Receive the next message and check its event type and snapshot data."""
        event = (await client.receive_json())["event"]
        assert event["type"] == event_type
        assert event["data"] == snapshot

    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/run",
            "start_stage": "stt",
            "end_stage": "tts",
            "input": {"sample_rate": 44100},
        }
    )

    # Command reply
    assert (await client.receive_json())["success"]

    await expect_event("run-start")
    await expect_event("stt-start")

    # End of audio stream (handler id + empty payload)
    await client.send_bytes(bytes([1]))

    remaining_events = (
        "stt-end",
        "intent-start",
        "intent-end",
        "tts-start",
        "tts-end",
    )
    for event_type in remaining_events:
        await expect_event(event_type)

    # run-end carries no payload.
    event = (await client.receive_json())["event"]
    assert event["type"] == "run-end"
    assert event["data"] is None
|
2023-03-17 00:42:26 +00:00
|
|
|
|
|
|
|
|
2023-03-23 18:44:19 +00:00
|
|
|
async def test_intent_timeout(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test partial pipeline run with conversation agent timeout."""
    client = await hass_ws_client(hass)

    async def sleepy_converse(*args, **kwargs):
        # Sleeps far longer than the 0.1 second run timeout requested below.
        await asyncio.sleep(3600)

    converse_patch = patch(
        "homeassistant.components.conversation.async_converse",
        new=sleepy_converse,
    )
    with converse_patch:
        await client.send_json_auto_id(
            {
                "type": "assist_pipeline/run",
                "start_stage": "intent",
                "end_stage": "intent",
                "input": {"text": "Are the lights on?"},
                "timeout": 0.1,
            }
        )

        # Command reply
        msg = await client.receive_json()
        assert msg["success"]

        # run-start, intent-start, then the timeout error; each payload is
        # snapshot-checked in this order.
        for event_type in ("run-start", "intent-start", "error"):
            msg = await client.receive_json()
            assert msg["event"]["type"] == event_type
            assert msg["event"]["data"] == snapshot
|
2023-03-17 00:42:26 +00:00
|
|
|
|
|
|
|
|
2023-03-23 18:44:19 +00:00
|
|
|
async def test_text_pipeline_timeout(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test text-only pipeline run with immediate timeout."""
    client = await hass_ws_client(hass)

    async def sleepy_run(*args, **kwargs):
        # Stalls far longer than the requested timeout.
        await asyncio.sleep(3600)

    request = {
        "type": "assist_pipeline/run",
        "start_stage": "intent",
        "end_stage": "intent",
        "input": {"text": "Are the lights on?"},
        "timeout": 0.0001,
    }

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.PipelineInput.execute",
        new=sleepy_run,
    ):
        await client.send_json_auto_id(request)

        # The run command is accepted...
        result = await client.receive_json()
        assert result["success"]

        # ...and the next message is the timeout error.
        event = (await client.receive_json())["event"]
        assert event["type"] == "error"
        assert event["data"] == snapshot
|
2023-03-23 18:44:19 +00:00
|
|
|
|
|
|
|
|
|
|
|
async def test_intent_failed(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test text-only pipeline run with conversation agent error."""
    client = await hass_ws_client(hass)

    # side_effect makes calling the patched agent raise RuntimeError.
    # return_value would hand back the RuntimeError class without raising it,
    # so the agent error this test is named after would never occur.
    with patch(
        "homeassistant.components.conversation.async_converse",
        new=MagicMock(side_effect=RuntimeError),
    ):
        await client.send_json_auto_id(
            {
                "type": "assist_pipeline/run",
                "start_stage": "intent",
                "end_stage": "intent",
                "input": {"text": "Are the lights on?"},
            }
        )

        # result
        msg = await client.receive_json()
        assert msg["success"]

        # run start
        msg = await client.receive_json()
        assert msg["event"]["type"] == "run-start"
        assert msg["event"]["data"] == snapshot

        # intent start
        msg = await client.receive_json()
        assert msg["event"]["type"] == "intent-start"
        assert msg["event"]["data"] == snapshot

        # intent error
        msg = await client.receive_json()
        assert msg["event"]["type"] == "error"
        assert msg["event"]["data"]["code"] == "intent-failed"
|
|
|
|
|
|
|
|
|
|
|
|
async def test_audio_pipeline_timeout(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test audio pipeline run with immediate timeout."""
    client = await hass_ws_client(hass)

    async def sleepy_run(*args, **kwargs):
        # Stalls far longer than the requested timeout.
        await asyncio.sleep(3600)

    request = {
        "type": "assist_pipeline/run",
        "start_stage": "stt",
        "end_stage": "tts",
        "input": {"sample_rate": 44100},
        "timeout": 0.0001,
    }

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.PipelineInput.execute",
        new=sleepy_run,
    ):
        await client.send_json_auto_id(request)

        # The run command is accepted...
        result = await client.receive_json()
        assert result["success"]

        # ...and the next message is the timeout error.
        event = (await client.receive_json())["event"]
        assert event["type"] == "error"
        assert event["data"]["code"] == "timeout"
|
2023-03-23 18:44:19 +00:00
|
|
|
|
|
|
|
|
|
|
|
async def test_stt_provider_missing(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test events from a pipeline run with a non-existent STT provider."""
    no_provider = MagicMock(return_value=None)

    with patch("homeassistant.components.stt.async_get_provider", new=no_provider):
        client = await hass_ws_client(hass)

        run_request = {
            "type": "assist_pipeline/run",
            "start_stage": "stt",
            "end_stage": "tts",
            "input": {"sample_rate": 44100},
        }
        await client.send_json_auto_id(run_request)

        # The command itself fails; no pipeline events are checked.
        response = await client.receive_json()
        assert not response["success"]
        assert response["error"]["code"] == "stt-provider-missing"
|
2023-03-23 18:44:19 +00:00
|
|
|
|
|
|
|
|
|
|
|
async def test_stt_stream_failed(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test events from a pipeline run whose STT audio stream processing fails."""
    with patch(
        "tests.components.assist_pipeline.conftest.MockSttProvider.async_process_audio_stream",
        new=MagicMock(side_effect=RuntimeError),
    ):
        client = await hass_ws_client(hass)

        await client.send_json_auto_id(
            {
                "type": "assist_pipeline/run",
                "start_stage": "stt",
                "end_stage": "tts",
                "input": {
                    "sample_rate": 44100,
                },
            }
        )

        # result
        msg = await client.receive_json()
        assert msg["success"]

        # run start
        msg = await client.receive_json()
        assert msg["event"]["type"] == "run-start"
        assert msg["event"]["data"] == snapshot

        # stt
        msg = await client.receive_json()
        assert msg["event"]["type"] == "stt-start"
        assert msg["event"]["data"] == snapshot

        # End of audio stream (handler id + empty payload).
        # bytes([1]) is handler id 1, matching test_audio_pipeline.
        # The previous b"1" was byte 0x31, i.e. handler id 49.
        await client.send_bytes(bytes([1]))

        # stt error
        msg = await client.receive_json()
        assert msg["event"]["type"] == "error"
        assert msg["event"]["data"]["code"] == "stt-stream-failed"
|
|
|
|
|
|
|
|
|
|
|
|
async def test_tts_failed(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    snapshot: SnapshotAssertion,
) -> None:
    """Test pipeline run with text to speech error."""
    client = await hass_ws_client(hass)

    # side_effect makes resolving the TTS media raise RuntimeError.
    # return_value would hand back the RuntimeError class without raising it.
    with patch(
        "homeassistant.components.media_source.async_resolve_media",
        new=MagicMock(side_effect=RuntimeError),
    ):
        # Use auto-incremented message ids like the rest of this module,
        # instead of a hard-coded id.
        await client.send_json_auto_id(
            {
                "type": "assist_pipeline/run",
                "start_stage": "tts",
                "end_stage": "tts",
                "input": {"text": "Lights are on."},
            }
        )

        # result
        msg = await client.receive_json()
        assert msg["success"]

        # run start
        msg = await client.receive_json()
        assert msg["event"]["type"] == "run-start"
        assert msg["event"]["data"] == snapshot

        # tts start
        msg = await client.receive_json()
        assert msg["event"]["type"] == "tts-start"
        assert msg["event"]["data"] == snapshot

        # tts error
        msg = await client.receive_json()
        assert msg["event"]["type"] == "error"
        assert msg["event"]["data"]["code"] == "tts-failed"
|
|
|
|
|
|
|
|
|
|
|
|
async def test_invalid_stage_order(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test pipeline run with invalid stage order."""
    client = await hass_ws_client(hass)

    # start_stage (tts) comes after end_stage (stt), so the run must be rejected.
    backwards_request = {
        "type": "assist_pipeline/run",
        "start_stage": "tts",
        "end_stage": "stt",
        "input": {"text": "Lights are on."},
    }
    await client.send_json_auto_id(backwards_request)

    response = await client.receive_json()
    assert not response["success"]
|
2023-04-06 16:55:16 +00:00
|
|
|
|
|
|
|
|
|
|
|
async def test_add_pipeline(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test we can add a pipeline."""
    client = await hass_ws_client(hass)
    pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]

    # These values are sent in the create request, echoed back in the result,
    # and stored on the Pipeline object.
    fields = {
        "conversation_engine": "test_conversation_engine",
        "language": "test_language",
        "name": "test_name",
        "stt_engine": "test_stt_engine",
        "tts_engine": "test_tts_engine",
    }

    await client.send_json_auto_id({"type": "assist_pipeline/pipeline/create", **fields})
    reply = await client.receive_json()
    assert reply["success"]
    assert reply["result"] == {**fields, "id": ANY}

    # Exactly one pipeline is stored, matching what was returned.
    assert len(pipeline_store.data) == 1
    new_id = reply["result"]["id"]
    stored = pipeline_store.data[new_id]
    assert stored == Pipeline(id=new_id, **fields)
|
|
|
|
|
|
|
|
|
|
|
|
async def test_delete_pipeline(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test we can delete a pipeline."""
    client = await hass_ws_client(hass)
    pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]

    async def create_pipeline():
        """Create a test pipeline over the websocket and return its id."""
        await client.send_json_auto_id(
            {
                "type": "assist_pipeline/pipeline/create",
                "conversation_engine": "test_conversation_engine",
                "language": "test_language",
                "name": "test_name",
                "stt_engine": "test_stt_engine",
                "tts_engine": "test_tts_engine",
            }
        )
        created = await client.receive_json()
        assert created["success"]
        return created["result"]["id"]

    async def delete_pipeline(target_id) -> dict:
        """Request deletion of a pipeline and return the raw reply."""
        await client.send_json_auto_id(
            {"type": "assist_pipeline/pipeline/delete", "pipeline_id": target_id}
        )
        return await client.receive_json()

    first_id = await create_pipeline()
    second_id = await create_pipeline()
    assert len(pipeline_store.data) == 2

    # The server refuses to delete the first pipeline because it is the
    # preferred one (see the error message).
    reply = await delete_pipeline(first_id)
    assert not reply["success"]
    assert reply["error"] == {
        "code": "not_allowed",
        "message": f"Item {first_id} preferred.",
    }

    # The second pipeline can be deleted.
    reply = await delete_pipeline(second_id)
    assert reply["success"]
    assert len(pipeline_store.data) == 1

    # Deleting it a second time fails with not_found.
    reply = await delete_pipeline(second_id)
    assert not reply["success"]
    assert reply["error"] == {
        "code": "not_found",
        "message": f"Unable to find pipeline_id {second_id}",
    }
|
|
|
|
|
|
|
|
|
|
|
|
async def test_list_pipelines(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test we can list pipelines."""
    client = await hass_ws_client(hass)
    pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]

    async def list_pipelines() -> dict:
        """Send a list command, check it succeeded, and return its result."""
        await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
        listing = await client.receive_json()
        assert listing["success"]
        return listing["result"]

    # Nothing has been stored yet.
    assert await list_pipelines() == {"pipelines": [], "preferred_pipeline": None}

    fields = {
        "conversation_engine": "test_conversation_engine",
        "language": "test_language",
        "name": "test_name",
        "stt_engine": "test_stt_engine",
        "tts_engine": "test_tts_engine",
    }
    await client.send_json_auto_id({"type": "assist_pipeline/pipeline/create", **fields})
    created = await client.receive_json()
    assert created["success"]
    assert len(pipeline_store.data) == 1

    # The new pipeline now shows up in the listing.
    assert await list_pipelines() == {
        "pipelines": [{**fields, "id": ANY}],
        "preferred_pipeline": ANY,
    }
|
2023-04-06 16:55:16 +00:00
|
|
|
|
|
|
|
|
|
|
|
async def test_update_pipeline(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test we can update a pipeline, and that updating an unknown id fails."""
    client = await hass_ws_client(hass)
    pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]

    # Updating a pipeline id that does not exist is rejected.
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/update",
            "conversation_engine": "new_conversation_engine",
            "language": "new_language",
            "name": "new_name",
            "pipeline_id": "no_such_pipeline",
            "stt_engine": "new_stt_engine",
            "tts_engine": "new_tts_engine",
        }
    )
    msg = await client.receive_json()
    assert not msg["success"]
    assert msg["error"] == {
        "code": "not_found",
        "message": "Unable to find pipeline_id no_such_pipeline",
    }

    # Create a pipeline to update.
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": "test_conversation_engine",
            "language": "test_language",
            "name": "test_name",
            "stt_engine": "test_stt_engine",
            "tts_engine": "test_tts_engine",
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    assert len(pipeline_store.data) == 1

    # Update every field of the existing pipeline.
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/update",
            "conversation_engine": "new_conversation_engine",
            "language": "new_language",
            "name": "new_name",
            "pipeline_id": pipeline_id,
            "stt_engine": "new_stt_engine",
            "tts_engine": "new_tts_engine",
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    assert msg["result"] == {
        "conversation_engine": "new_conversation_engine",
        "id": pipeline_id,
        "language": "new_language",
        "name": "new_name",
        "stt_engine": "new_stt_engine",
        "tts_engine": "new_tts_engine",
    }

    # The update replaced the stored pipeline in place; no new entry was added.
    assert len(pipeline_store.data) == 1
    pipeline = pipeline_store.data[pipeline_id]
    assert pipeline == Pipeline(
        conversation_engine="new_conversation_engine",
        id=pipeline_id,
        language="new_language",
        name="new_name",
        stt_engine="new_stt_engine",
        tts_engine="new_tts_engine",
    )
|
2023-04-15 14:05:46 +00:00
|
|
|
|
|
|
|
|
|
|
|
async def test_set_preferred_pipeline(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test updating the preferred pipeline."""
    client = await hass_ws_client(hass)
    pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]

    async def create_pipeline():
        """Create a test pipeline over the websocket and return its id."""
        await client.send_json_auto_id(
            {
                "type": "assist_pipeline/pipeline/create",
                "conversation_engine": "test_conversation_engine",
                "language": "test_language",
                "name": "test_name",
                "stt_engine": "test_stt_engine",
                "tts_engine": "test_tts_engine",
            }
        )
        created = await client.receive_json()
        assert created["success"]
        return created["result"]["id"]

    first_id = await create_pipeline()
    second_id = await create_pipeline()

    # Before any explicit choice, the first created pipeline is preferred.
    assert pipeline_store.async_get_preferred_item() == first_id

    await client.send_json_auto_id(
        {"type": "assist_pipeline/pipeline/set_preferred", "pipeline_id": second_id}
    )
    reply = await client.receive_json()
    assert reply["success"]

    # The preference now points at the second pipeline.
    assert pipeline_store.async_get_preferred_item() == second_id
|
|
|
|
|
|
|
|
|
|
|
|
async def test_set_preferred_pipeline_wrong_id(
    hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
    """Test setting the preferred pipeline to an unknown id fails."""
    client = await hass_ws_client(hass)

    await client.send_json_auto_id(
        {"type": "assist_pipeline/pipeline/set_preferred", "pipeline_id": "don_t_exist"}
    )
    msg = await client.receive_json()
    # Check that the command failed, not only which error code it reported.
    assert not msg["success"]
    assert msg["error"]["code"] == "not_found"
|