Add WS command for getting an assist pipeline (#91725)

* Add WS command for getting an assist pipeline

* Return preferred pipeline if none is specified
pull/91757/head
Erik Montnemery 2023-04-20 15:15:19 +02:00 committed by GitHub
parent a419c78524
commit af193094b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 124 additions and 11 deletions

View File

@ -677,6 +677,18 @@ class PipelineStorageCollectionWebsocket(
"""Set up the websocket commands.""" """Set up the websocket commands."""
super().async_setup(hass, create_list=create_list, create_create=create_create) super().async_setup(hass, create_list=create_list, create_create=create_create)
websocket_api.async_register_command(
hass,
f"{self.api_prefix}/get",
self.ws_get_item,
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
{
vol.Required("type"): f"{self.api_prefix}/get",
vol.Optional(self.item_id_key): str,
}
),
)
websocket_api.async_register_command( websocket_api.async_register_command(
hass, hass,
f"{self.api_prefix}/set_preferred", f"{self.api_prefix}/set_preferred",
@ -691,6 +703,36 @@ class PipelineStorageCollectionWebsocket(
), ),
) )
async def ws_delete_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Delete an item."""
try:
await super().ws_delete_item(hass, connection, msg)
except PipelinePreferred as exc:
connection.send_error(
msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc)
)
@callback
def ws_get_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Get an item."""
item_id = msg.get(self.item_id_key)
if item_id is None:
item_id = self.storage_collection.async_get_preferred_item()
if item_id not in self.storage_collection.data:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
f"Unable to find {self.item_id_key} {item_id}",
)
return
connection.send_result(msg["id"], self.storage_collection.data[item_id])
@callback @callback
def ws_list_item( def ws_list_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
@ -704,17 +746,6 @@ class PipelineStorageCollectionWebsocket(
}, },
) )
async def ws_delete_item(
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Delete an item."""
try:
await super().ws_delete_item(hass, connection, msg)
except PipelinePreferred as exc:
connection.send_error(
msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc)
)
async def ws_set_preferred_item( async def ws_set_preferred_item(
self, self,
hass: HomeAssistant, hass: HomeAssistant,

View File

@ -718,6 +718,88 @@ async def test_delete_pipeline(
} }
async def test_get_pipeline(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None:
"""Test we can get a pipeline."""
client = await hass_ws_client(hass)
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/get",
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": "Unable to find pipeline_id None",
}
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/get",
"pipeline_id": "no_such_pipeline",
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": "Unable to find pipeline_id no_such_pipeline",
}
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "test_conversation_engine",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id = msg["result"]["id"]
assert len(pipeline_store.data) == 1
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/get",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"conversation_engine": "test_conversation_engine",
"id": pipeline_id,
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/get",
"pipeline_id": pipeline_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"conversation_engine": "test_conversation_engine",
"id": pipeline_id,
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"tts_engine": "test_tts_engine",
}
async def test_list_pipelines( async def test_list_pipelines(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
) -> None: ) -> None: