From 4da04a358ab00ea471f86198aa6689b6bb74f3b2 Mon Sep 17 00:00:00 2001 From: Martin Hjelmare Date: Thu, 14 Dec 2023 23:56:08 +0100 Subject: [PATCH] Refactor cloud assist pipeline (#105723) * Refactor cloud assist pipeline * Return None early --- .../components/cloud/assist_pipeline.py | 44 +++++++++++++++++++ homeassistant/components/cloud/http_api.py | 34 +++----------- tests/components/cloud/test_http_api.py | 6 +-- 3 files changed, 54 insertions(+), 30 deletions(-) create mode 100644 homeassistant/components/cloud/assist_pipeline.py diff --git a/homeassistant/components/cloud/assist_pipeline.py b/homeassistant/components/cloud/assist_pipeline.py new file mode 100644 index 00000000000..8054b3bd953 --- /dev/null +++ b/homeassistant/components/cloud/assist_pipeline.py @@ -0,0 +1,44 @@ +"""Handle Cloud assist pipelines.""" +from homeassistant.components.assist_pipeline import ( + async_create_default_pipeline, + async_get_pipelines, + async_setup_pipeline_store, +) +from homeassistant.components.conversation import HOME_ASSISTANT_AGENT +from homeassistant.core import HomeAssistant + +from .const import DOMAIN + + +async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None: + """Create a cloud assist pipeline.""" + # Make sure the pipeline store is loaded, needed because assist_pipeline + # is an after dependency of cloud + await async_setup_pipeline_store(hass) + + def cloud_assist_pipeline(hass: HomeAssistant) -> str | None: + """Return the ID of a cloud-enabled assist pipeline or None. + + Check if a cloud pipeline already exists with + legacy cloud engine id. + """ + for pipeline in async_get_pipelines(hass): + if ( + pipeline.conversation_engine == HOME_ASSISTANT_AGENT + and pipeline.stt_engine == DOMAIN + and pipeline.tts_engine == DOMAIN + ): + return pipeline.id + return None + + if (cloud_assist_pipeline(hass)) is not None or ( + cloud_pipeline := await async_create_default_pipeline( + hass, + stt_engine_id=DOMAIN, + tts_engine_id=DOMAIN, + pipeline_name="Home Assistant Cloud", + ) + ) is None: + return None + + return cloud_pipeline.id diff --git a/homeassistant/components/cloud/http_api.py b/homeassistant/components/cloud/http_api.py index 467ce3bcc0b..c937a415cda 100644 --- a/homeassistant/components/cloud/http_api.py +++ b/homeassistant/components/cloud/http_api.py @@ -1,4 +1,6 @@ """The HTTP api to control the cloud integration.""" +from __future__ import annotations + import asyncio from collections.abc import Awaitable, Callable, Coroutine, Mapping from contextlib import suppress @@ -16,7 +18,7 @@ from hass_nabucasa.const import STATE_DISCONNECTED from hass_nabucasa.voice import MAP_VOICE import voluptuous as vol -from homeassistant.components import assist_pipeline, conversation, websocket_api +from homeassistant.components import websocket_api from homeassistant.components.alexa import ( entities as alexa_entities, errors as alexa_errors, @@ -32,6 +34,7 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.util.location import async_detect_location_info from .alexa_config import entity_supported as entity_supported_by_alexa +from .assist_pipeline import async_create_cloud_pipeline from .client import CloudClient from .const import ( DOMAIN, @@ -210,34 +213,11 @@ class CloudLoginView(HomeAssistantView): ) async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response: """Handle login request.""" - - def cloud_assist_pipeline(hass: HomeAssistant) -> str | None: - """Return the ID of a cloud-enabled assist pipeline or None.""" - for pipeline in assist_pipeline.async_get_pipelines(hass): - if ( - pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT - and pipeline.stt_engine == DOMAIN - and pipeline.tts_engine == DOMAIN - ): - return pipeline.id - return None - - hass = request.app["hass"] - cloud = hass.data[DOMAIN] + hass: HomeAssistant = request.app["hass"] + cloud: Cloud[CloudClient] = hass.data[DOMAIN] await cloud.login(data["email"], data["password"]) - # Make sure the pipeline store is loaded, needed because assist_pipeline - # is an after dependency of cloud - await assist_pipeline.async_setup_pipeline_store(hass) - new_cloud_pipeline_id: str | None = None - if (cloud_assist_pipeline(hass)) is None: - if cloud_pipeline := await assist_pipeline.async_create_default_pipeline( - hass, - stt_engine_id=DOMAIN, - tts_engine_id=DOMAIN, - pipeline_name="Home Assistant Cloud", - ): - new_cloud_pipeline_id = cloud_pipeline.id + new_cloud_pipeline_id = await async_create_cloud_pipeline(hass) return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id}) diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index 2520c10b4de..a04729faf67 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -150,7 +150,7 @@ async def test_login_view_existing_pipeline( cloud_client = await hass_client() with patch( - "homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline", + "homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline", ) as create_pipeline_mock: req = await cloud_client.post( "/api/cloud/login", json={"email": "my_username", "password": "my_password"} @@ -183,7 +183,7 @@ async def test_login_view_create_pipeline( cloud_client = await hass_client() with patch( - "homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline", + "homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline", return_value=AsyncMock(id="12345"), ) as create_pipeline_mock: req = await cloud_client.post( @@ -222,7 +222,7 @@ async def test_login_view_create_pipeline_fail( cloud_client = await hass_client() with patch( - "homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline", + "homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline", return_value=None, ) as create_pipeline_mock: req = await cloud_client.post(