Add WS command for getting an assist pipeline (#91725)
* Add WS command for getting an assist pipeline * Return preferred pipeline if none is specifiedpull/91757/head
parent
a419c78524
commit
af193094b5
|
@ -677,6 +677,18 @@ class PipelineStorageCollectionWebsocket(
|
||||||
"""Set up the websocket commands."""
|
"""Set up the websocket commands."""
|
||||||
super().async_setup(hass, create_list=create_list, create_create=create_create)
|
super().async_setup(hass, create_list=create_list, create_create=create_create)
|
||||||
|
|
||||||
|
websocket_api.async_register_command(
|
||||||
|
hass,
|
||||||
|
f"{self.api_prefix}/get",
|
||||||
|
self.ws_get_item,
|
||||||
|
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||||
|
{
|
||||||
|
vol.Required("type"): f"{self.api_prefix}/get",
|
||||||
|
vol.Optional(self.item_id_key): str,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
websocket_api.async_register_command(
|
websocket_api.async_register_command(
|
||||||
hass,
|
hass,
|
||||||
f"{self.api_prefix}/set_preferred",
|
f"{self.api_prefix}/set_preferred",
|
||||||
|
@ -691,6 +703,36 @@ class PipelineStorageCollectionWebsocket(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def ws_delete_item(
|
||||||
|
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||||
|
) -> None:
|
||||||
|
"""Delete an item."""
|
||||||
|
try:
|
||||||
|
await super().ws_delete_item(hass, connection, msg)
|
||||||
|
except PipelinePreferred as exc:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc)
|
||||||
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def ws_get_item(
|
||||||
|
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||||
|
) -> None:
|
||||||
|
"""Get an item."""
|
||||||
|
item_id = msg.get(self.item_id_key)
|
||||||
|
if item_id is None:
|
||||||
|
item_id = self.storage_collection.async_get_preferred_item()
|
||||||
|
|
||||||
|
if item_id not in self.storage_collection.data:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"],
|
||||||
|
websocket_api.const.ERR_NOT_FOUND,
|
||||||
|
f"Unable to find {self.item_id_key} {item_id}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
connection.send_result(msg["id"], self.storage_collection.data[item_id])
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def ws_list_item(
|
def ws_list_item(
|
||||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
||||||
|
@ -704,17 +746,6 @@ class PipelineStorageCollectionWebsocket(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def ws_delete_item(
|
|
||||||
self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
|
|
||||||
) -> None:
|
|
||||||
"""Delete an item."""
|
|
||||||
try:
|
|
||||||
await super().ws_delete_item(hass, connection, msg)
|
|
||||||
except PipelinePreferred as exc:
|
|
||||||
connection.send_error(
|
|
||||||
msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def ws_set_preferred_item(
|
async def ws_set_preferred_item(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
|
|
@ -718,6 +718,88 @@ async def test_delete_pipeline(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_pipeline(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
|
) -> None:
|
||||||
|
"""Test we can get a pipeline."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/get",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
|
"code": "not_found",
|
||||||
|
"message": "Unable to find pipeline_id None",
|
||||||
|
}
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/get",
|
||||||
|
"pipeline_id": "no_such_pipeline",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"] == {
|
||||||
|
"code": "not_found",
|
||||||
|
"message": "Unable to find pipeline_id no_such_pipeline",
|
||||||
|
}
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/create",
|
||||||
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"language": "test_language",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test_stt_engine",
|
||||||
|
"tts_engine": "test_tts_engine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
pipeline_id = msg["result"]["id"]
|
||||||
|
assert len(pipeline_store.data) == 1
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/get",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {
|
||||||
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"id": pipeline_id,
|
||||||
|
"language": "test_language",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test_stt_engine",
|
||||||
|
"tts_engine": "test_tts_engine",
|
||||||
|
}
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/get",
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {
|
||||||
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"id": pipeline_id,
|
||||||
|
"language": "test_language",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test_stt_engine",
|
||||||
|
"tts_engine": "test_tts_engine",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def test_list_pipelines(
|
async def test_list_pipelines(
|
||||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
Loading…
Reference in New Issue