Refactor cloud assist pipeline (#105723)
* Refactor cloud assist pipeline * Return None earlypull/105784/head
parent
f4c8920231
commit
4da04a358a
|
@ -0,0 +1,44 @@
|
|||
"""Handle Cloud assist pipelines."""
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
async_create_default_pipeline,
|
||||
async_get_pipelines,
|
||||
async_setup_pipeline_store,
|
||||
)
|
||||
from homeassistant.components.conversation import HOME_ASSISTANT_AGENT
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
|
||||
async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
||||
"""Create a cloud assist pipeline."""
|
||||
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
||||
# is an after dependency of cloud
|
||||
await async_setup_pipeline_store(hass)
|
||||
|
||||
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
|
||||
"""Return the ID of a cloud-enabled assist pipeline or None.
|
||||
|
||||
Check if a cloud pipeline already exists with
|
||||
legacy cloud engine id.
|
||||
"""
|
||||
for pipeline in async_get_pipelines(hass):
|
||||
if (
|
||||
pipeline.conversation_engine == HOME_ASSISTANT_AGENT
|
||||
and pipeline.stt_engine == DOMAIN
|
||||
and pipeline.tts_engine == DOMAIN
|
||||
):
|
||||
return pipeline.id
|
||||
return None
|
||||
|
||||
if (cloud_assist_pipeline(hass)) is not None or (
|
||||
cloud_pipeline := await async_create_default_pipeline(
|
||||
hass,
|
||||
stt_engine_id=DOMAIN,
|
||||
tts_engine_id=DOMAIN,
|
||||
pipeline_name="Home Assistant Cloud",
|
||||
)
|
||||
) is None:
|
||||
return None
|
||||
|
||||
return cloud_pipeline.id
|
|
@ -1,4 +1,6 @@
|
|||
"""The HTTP api to control the cloud integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable, Coroutine, Mapping
|
||||
from contextlib import suppress
|
||||
|
@ -16,7 +18,7 @@ from hass_nabucasa.const import STATE_DISCONNECTED
|
|||
from hass_nabucasa.voice import MAP_VOICE
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import assist_pipeline, conversation, websocket_api
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.alexa import (
|
||||
entities as alexa_entities,
|
||||
errors as alexa_errors,
|
||||
|
@ -32,6 +34,7 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
|||
from homeassistant.util.location import async_detect_location_info
|
||||
|
||||
from .alexa_config import entity_supported as entity_supported_by_alexa
|
||||
from .assist_pipeline import async_create_cloud_pipeline
|
||||
from .client import CloudClient
|
||||
from .const import (
|
||||
DOMAIN,
|
||||
|
@ -210,34 +213,11 @@ class CloudLoginView(HomeAssistantView):
|
|||
)
|
||||
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
||||
"""Handle login request."""
|
||||
|
||||
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
|
||||
"""Return the ID of a cloud-enabled assist pipeline or None."""
|
||||
for pipeline in assist_pipeline.async_get_pipelines(hass):
|
||||
if (
|
||||
pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT
|
||||
and pipeline.stt_engine == DOMAIN
|
||||
and pipeline.tts_engine == DOMAIN
|
||||
):
|
||||
return pipeline.id
|
||||
return None
|
||||
|
||||
hass = request.app["hass"]
|
||||
cloud = hass.data[DOMAIN]
|
||||
hass: HomeAssistant = request.app["hass"]
|
||||
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
|
||||
await cloud.login(data["email"], data["password"])
|
||||
|
||||
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
||||
# is an after dependency of cloud
|
||||
await assist_pipeline.async_setup_pipeline_store(hass)
|
||||
new_cloud_pipeline_id: str | None = None
|
||||
if (cloud_assist_pipeline(hass)) is None:
|
||||
if cloud_pipeline := await assist_pipeline.async_create_default_pipeline(
|
||||
hass,
|
||||
stt_engine_id=DOMAIN,
|
||||
tts_engine_id=DOMAIN,
|
||||
pipeline_name="Home Assistant Cloud",
|
||||
):
|
||||
new_cloud_pipeline_id = cloud_pipeline.id
|
||||
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)
|
||||
return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id})
|
||||
|
||||
|
||||
|
|
|
@ -150,7 +150,7 @@ async def test_login_view_existing_pipeline(
|
|||
cloud_client = await hass_client()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
|
||||
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
|
||||
) as create_pipeline_mock:
|
||||
req = await cloud_client.post(
|
||||
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
||||
|
@ -183,7 +183,7 @@ async def test_login_view_create_pipeline(
|
|||
cloud_client = await hass_client()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
|
||||
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
|
||||
return_value=AsyncMock(id="12345"),
|
||||
) as create_pipeline_mock:
|
||||
req = await cloud_client.post(
|
||||
|
@ -222,7 +222,7 @@ async def test_login_view_create_pipeline_fail(
|
|||
cloud_client = await hass_client()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
|
||||
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
|
||||
return_value=None,
|
||||
) as create_pipeline_mock:
|
||||
req = await cloud_client.post(
|
||||
|
|
Loading…
Reference in New Issue