diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 3ea46875fea..677d2a56664 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -677,6 +677,18 @@ class PipelineStorageCollectionWebsocket( """Set up the websocket commands.""" super().async_setup(hass, create_list=create_list, create_create=create_create) + websocket_api.async_register_command( + hass, + f"{self.api_prefix}/get", + self.ws_get_item, + websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( + { + vol.Required("type"): f"{self.api_prefix}/get", + vol.Optional(self.item_id_key): str, + } + ), + ) + websocket_api.async_register_command( hass, f"{self.api_prefix}/set_preferred", @@ -691,6 +703,36 @@ class PipelineStorageCollectionWebsocket( ), ) + async def ws_delete_item( + self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + ) -> None: + """Delete an item.""" + try: + await super().ws_delete_item(hass, connection, msg) + except PipelinePreferred as exc: + connection.send_error( + msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc) + ) + + @callback + def ws_get_item( + self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + ) -> None: + """Get an item.""" + item_id = msg.get(self.item_id_key) + if item_id is None: + item_id = self.storage_collection.async_get_preferred_item() + + if item_id not in self.storage_collection.data: + connection.send_error( + msg["id"], + websocket_api.const.ERR_NOT_FOUND, + f"Unable to find {self.item_id_key} {item_id}", + ) + return + + connection.send_result(msg["id"], self.storage_collection.data[item_id]) + @callback def ws_list_item( self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict @@ -704,17 +746,6 @@ class PipelineStorageCollectionWebsocket( }, ) - async def ws_delete_item( - self, hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict - ) -> None: - """Delete an item.""" - try: - await super().ws_delete_item(hass, connection, msg) - except PipelinePreferred as exc: - connection.send_error( - msg["id"], websocket_api.const.ERR_NOT_ALLOWED, str(exc) - ) - async def ws_set_preferred_item( self, hass: HomeAssistant, diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index ffef4b3192e..6de01e74ea9 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -718,6 +718,88 @@ async def test_delete_pipeline( } +async def test_get_pipeline( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test we can get a pipeline.""" + client = await hass_ws_client(hass) + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_store = pipeline_data.pipeline_store + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/get", + } + ) + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "not_found", + "message": "Unable to find pipeline_id None", + } + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/get", + "pipeline_id": "no_such_pipeline", + } + ) + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "not_found", + "message": "Unable to find pipeline_id no_such_pipeline", + } + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "test_conversation_engine", + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + ) + msg = await client.receive_json() + assert msg["success"] + pipeline_id = msg["result"]["id"] + assert len(pipeline_store.data) == 1 + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/get", + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "conversation_engine": "test_conversation_engine", + "id": pipeline_id, + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/get", + "pipeline_id": pipeline_id, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] == { + "conversation_engine": "test_conversation_engine", + "id": pipeline_id, + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "tts_engine": "test_tts_engine", + } + + async def test_list_pipelines( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components ) -> None: