Refactor cloud assist pipeline (#105723)

* Refactor cloud assist pipeline

* Return None early
pull/105784/head
Martin Hjelmare 2023-12-14 23:56:08 +01:00 committed by GitHub
parent f4c8920231
commit 4da04a358a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 30 deletions

View File

@ -0,0 +1,44 @@
"""Handle Cloud assist pipelines."""
from homeassistant.components.assist_pipeline import (
async_create_default_pipeline,
async_get_pipelines,
async_setup_pipeline_store,
)
from homeassistant.components.conversation import HOME_ASSISTANT_AGENT
from homeassistant.core import HomeAssistant
from .const import DOMAIN
async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
"""Create a cloud assist pipeline."""
# Make sure the pipeline store is loaded, needed because assist_pipeline
# is an after dependency of cloud
await async_setup_pipeline_store(hass)
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
"""Return the ID of a cloud-enabled assist pipeline or None.
Check if a cloud pipeline already exists with
legacy cloud engine id.
"""
for pipeline in async_get_pipelines(hass):
if (
pipeline.conversation_engine == HOME_ASSISTANT_AGENT
and pipeline.stt_engine == DOMAIN
and pipeline.tts_engine == DOMAIN
):
return pipeline.id
return None
if (cloud_assist_pipeline(hass)) is not None or (
cloud_pipeline := await async_create_default_pipeline(
hass,
stt_engine_id=DOMAIN,
tts_engine_id=DOMAIN,
pipeline_name="Home Assistant Cloud",
)
) is None:
return None
return cloud_pipeline.id

View File

@ -1,4 +1,6 @@
"""The HTTP api to control the cloud integration."""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable, Coroutine, Mapping
from contextlib import suppress
@ -16,7 +18,7 @@ from hass_nabucasa.const import STATE_DISCONNECTED
from hass_nabucasa.voice import MAP_VOICE
import voluptuous as vol
from homeassistant.components import assist_pipeline, conversation, websocket_api
from homeassistant.components import websocket_api
from homeassistant.components.alexa import (
entities as alexa_entities,
errors as alexa_errors,
@ -32,6 +34,7 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.util.location import async_detect_location_info
from .alexa_config import entity_supported as entity_supported_by_alexa
from .assist_pipeline import async_create_cloud_pipeline
from .client import CloudClient
from .const import (
DOMAIN,
@ -210,34 +213,11 @@ class CloudLoginView(HomeAssistantView):
)
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
"""Handle login request."""
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
"""Return the ID of a cloud-enabled assist pipeline or None."""
for pipeline in assist_pipeline.async_get_pipelines(hass):
if (
pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT
and pipeline.stt_engine == DOMAIN
and pipeline.tts_engine == DOMAIN
):
return pipeline.id
return None
hass = request.app["hass"]
cloud = hass.data[DOMAIN]
hass: HomeAssistant = request.app["hass"]
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
await cloud.login(data["email"], data["password"])
# Make sure the pipeline store is loaded, needed because assist_pipeline
# is an after dependency of cloud
await assist_pipeline.async_setup_pipeline_store(hass)
new_cloud_pipeline_id: str | None = None
if (cloud_assist_pipeline(hass)) is None:
if cloud_pipeline := await assist_pipeline.async_create_default_pipeline(
hass,
stt_engine_id=DOMAIN,
tts_engine_id=DOMAIN,
pipeline_name="Home Assistant Cloud",
):
new_cloud_pipeline_id = cloud_pipeline.id
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)
return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id})

View File

@ -150,7 +150,7 @@ async def test_login_view_existing_pipeline(
cloud_client = await hass_client()
with patch(
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
) as create_pipeline_mock:
req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
@ -183,7 +183,7 @@ async def test_login_view_create_pipeline(
cloud_client = await hass_client()
with patch(
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
return_value=AsyncMock(id="12345"),
) as create_pipeline_mock:
req = await cloud_client.post(
@ -222,7 +222,7 @@ async def test_login_view_create_pipeline_fail(
cloud_client = await hass_client()
with patch(
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
return_value=None,
) as create_pipeline_mock:
req = await cloud_client.post(