Add WS command for getting an assist pipeline (#91725)
* Add WS command for getting an assist pipeline * Return preferred pipeline if none is specifiedpull/91757/head
parent
a419c78524
commit
af193094b5
|
@ -677,6 +677,18 @@ class PipelineStorageCollectionWebsocket(
|
|||
"""Set up the websocket commands."""
|
||||
super().async_setup(hass, create_list=create_list, create_create=create_create)
|
||||
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
f"{self.api_prefix}/get",
|
||||
self.ws_get_item,
|
||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required("type"): f"{self.api_prefix}/get",
|
||||
vol.Optional(self.item_id_key): str,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
f"{self.api_prefix}/set_preferred",
|
||||
|
@ -691,6 +703,36 @@ class PipelineStorageCollectionWebsocket(
|
|||
),
|
||||
)
|
||||
|
||||
async def ws_delete_item(
|
||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""Delete an item."""
|
||||
try:
|
||||
await super().ws_delete_item(hass, connection, msg)
|
||||
except PipelinePreferred as exc:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc)
|
||||
)
|
||||
|
||||
@callback
|
||||
def ws_get_item(
|
||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""Get an item."""
|
||||
item_id = msg.get(self.item_id_key)
|
||||
if item_id is None:
|
||||
item_id = self.storage_collection.async_get_preferred_item()
|
||||
|
||||
if item_id not in self.storage_collection.data:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
websocket_api.const.ERR_NOT_FOUND,
|
||||
f"Unable to find {self.item_id_key} {item_id}",
|
||||
)
|
||||
return
|
||||
|
||||
connection.send_result(msg["id"], self.storage_collection.data[item_id])
|
||||
|
||||
@callback
|
||||
def ws_list_item(
|
||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
|
@ -704,17 +746,6 @@ class PipelineStorageCollectionWebsocket(
|
|||
},
|
||||
)
|
||||
|
||||
async def ws_delete_item(
|
||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||
) -> None:
|
||||
"""Delete an item."""
|
||||
try:
|
||||
await super().ws_delete_item(hass, connection, msg)
|
||||
except PipelinePreferred as exc:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc)
|
||||
)
|
||||
|
||||
async def ws_set_preferred_item(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
|
|
|
@ -718,6 +718,88 @@ async def test_delete_pipeline(
|
|||
}
|
||||
|
||||
|
||||
async def test_get_pipeline(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
) -> None:
|
||||
"""Test we can get a pipeline."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/get",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "Unable to find pipeline_id None",
|
||||
}
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/get",
|
||||
"pipeline_id": "no_such_pipeline",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "Unable to find pipeline_id no_such_pipeline",
|
||||
}
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id = msg["result"]["id"]
|
||||
assert len(pipeline_store.data) == 1
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/get",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"id": pipeline_id,
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"id": pipeline_id,
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"tts_engine": "test_tts_engine",
|
||||
}
|
||||
|
||||
|
||||
async def test_list_pipelines(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||
) -> None:
|
||||
|
|
Loading…
Reference in New Issue