"""Config flow for OctoPrint integration.""" from __future__ import annotations import asyncio from collections.abc import Mapping import logging from typing import TYPE_CHECKING, Any import aiohttp from pyoctoprintapi import ApiError, OctoprintClient, OctoprintException import voluptuous as vol from yarl import URL from homeassistant.components import ssdp, zeroconf from homeassistant.config_entries import ConfigFlow, ConfigFlowResult from homeassistant.const import ( CONF_API_KEY, CONF_HOST, CONF_PATH, CONF_PORT, CONF_SSL, CONF_USERNAME, CONF_VERIFY_SSL, ) from homeassistant.data_entry_flow import AbortFlow from homeassistant.exceptions import HomeAssistantError import homeassistant.helpers.config_validation as cv from homeassistant.util.ssl import get_default_context, get_default_no_verify_context from .const import DOMAIN _LOGGER = logging.getLogger(__name__) def _schema_with_defaults( username="", host="", port=80, path="/", ssl=False, verify_ssl=True ): return vol.Schema( { vol.Required(CONF_USERNAME, default=username): str, vol.Required(CONF_HOST, default=host): str, vol.Required(CONF_PORT, default=port): cv.port, vol.Required(CONF_PATH, default=path): str, vol.Required(CONF_SSL, default=ssl): bool, vol.Required(CONF_VERIFY_SSL, default=verify_ssl): bool, }, extra=vol.ALLOW_EXTRA, ) class OctoPrintConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for OctoPrint.""" VERSION = 1 api_key_task: asyncio.Task[None] | None = None discovery_schema: vol.Schema | None = None _reauth_data: dict[str, Any] | None = None _user_input: dict[str, Any] | None = None def __init__(self) -> None: """Handle a config flow for OctoPrint.""" self._sessions: list[aiohttp.ClientSession] = [] async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Handle the initial step.""" # When coming back from the progress steps, the user_input is stored in the # instance variable instead of being passed in if user_input is None and self._user_input: user_input = self._user_input if user_input is None: data = self.discovery_schema or _schema_with_defaults() return self.async_show_form(step_id="user", data_schema=data) if CONF_API_KEY in user_input: errors = {} try: return await self._finish_config(user_input) except AbortFlow as err: raise err from None except CannotConnect: errors["base"] = "cannot_connect" except Exception: # noqa: BLE001 errors["base"] = "unknown" if errors: return self.async_show_form( step_id="user", errors=errors, data_schema=_schema_with_defaults( user_input.get(CONF_USERNAME), user_input[CONF_HOST], user_input[CONF_PORT], user_input[CONF_PATH], user_input[CONF_SSL], user_input[CONF_VERIFY_SSL], ), ) self._user_input = user_input return await self.async_step_get_api_key() async def async_step_get_api_key( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Get an Application Api Key.""" if not self.api_key_task: self.api_key_task = self.hass.async_create_task( self._async_get_auth_key(), eager_start=False ) if not self.api_key_task.done(): return self.async_show_progress( step_id="get_api_key", progress_action="get_api_key", progress_task=self.api_key_task, ) try: await self.api_key_task except OctoprintException: _LOGGER.exception("Failed to get an application key") return self.async_show_progress_done(next_step_id="auth_failed") except Exception: _LOGGER.exception("Failed to get an application key") return self.async_show_progress_done(next_step_id="auth_failed") finally: self.api_key_task = None return self.async_show_progress_done(next_step_id="user") async def _finish_config(self, user_input: dict[str, Any]) -> ConfigFlowResult: """Finish the configuration setup.""" existing_entry = await self.async_set_unique_id(self.unique_id) if existing_entry is not None: self.hass.config_entries.async_update_entry(existing_entry, data=user_input) # Reload the config entry otherwise devices will remain unavailable self.hass.async_create_task( self.hass.config_entries.async_reload(existing_entry.entry_id), ) return self.async_abort(reason="reauth_successful") octoprint = self._get_octoprint_client(user_input) octoprint.set_api_key(user_input[CONF_API_KEY]) try: discovery = await octoprint.get_discovery_info() except ApiError as err: _LOGGER.error("Failed to connect to printer") raise CannotConnect from err await self.async_set_unique_id(discovery.upnp_uuid, raise_on_progress=False) self._abort_if_unique_id_configured() return self.async_create_entry(title=user_input[CONF_HOST], data=user_input) async def async_step_auth_failed(self, user_input: None) -> ConfigFlowResult: """Handle api fetch failure.""" return self.async_abort(reason="auth_failed") async def async_step_import(self, import_data: dict[str, Any]) -> ConfigFlowResult: """Handle import.""" return await self.async_step_user(import_data) async def async_step_zeroconf( self, discovery_info: zeroconf.ZeroconfServiceInfo ) -> ConfigFlowResult: """Handle discovery flow.""" uuid = discovery_info.properties["uuid"] await self.async_set_unique_id(uuid) self._abort_if_unique_id_configured() self.context.update( { "title_placeholders": {CONF_HOST: discovery_info.host}, "configuration_url": ( f"http://{discovery_info.host}:{discovery_info.port}" f"{discovery_info.properties[CONF_PATH]}" ), } ) self.discovery_schema = _schema_with_defaults( host=discovery_info.host, port=discovery_info.port, path=discovery_info.properties[CONF_PATH], ) return await self.async_step_user() async def async_step_ssdp( self, discovery_info: ssdp.SsdpServiceInfo ) -> ConfigFlowResult: """Handle ssdp discovery flow.""" uuid = discovery_info.upnp["UDN"][5:] await self.async_set_unique_id(uuid) self._abort_if_unique_id_configured() url = URL(discovery_info.upnp["presentationURL"]) self.context.update( { "title_placeholders": {CONF_HOST: url.host}, "configuration_url": discovery_info.upnp["presentationURL"], } ) self.discovery_schema = _schema_with_defaults( host=url.host, path=url.path, port=url.port, ssl=url.scheme == "https", ) return await self.async_step_user() async def async_step_reauth( self, entry_data: Mapping[str, Any] ) -> ConfigFlowResult: """Handle reauthorization request from Octoprint.""" self._reauth_data = dict(entry_data) self.context.update( { "title_placeholders": {CONF_HOST: entry_data[CONF_HOST]}, } ) return await self.async_step_reauth_confirm() async def async_step_reauth_confirm( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Handle reauthorization flow.""" assert self._reauth_data is not None if user_input is None: return self.async_show_form( step_id="reauth_confirm", data_schema=vol.Schema( { vol.Required( CONF_USERNAME, default=self._reauth_data[CONF_USERNAME] ): str, } ), ) self._reauth_data[CONF_USERNAME] = user_input[CONF_USERNAME] self._user_input = self._reauth_data return await self.async_step_get_api_key() async def _async_get_auth_key(self) -> None: """Get application api key.""" if TYPE_CHECKING: assert self._user_input is not None octoprint = self._get_octoprint_client(self._user_input) self._user_input[CONF_API_KEY] = await octoprint.request_app_key( "Home Assistant", self._user_input[CONF_USERNAME], 300 ) def _get_octoprint_client(self, user_input: dict[str, Any]) -> OctoprintClient: """Build an octoprint client from the user_input.""" verify_ssl = user_input.get(CONF_VERIFY_SSL, True) connector = aiohttp.TCPConnector( force_close=True, ssl=get_default_no_verify_context() if not verify_ssl else get_default_context(), ) session = aiohttp.ClientSession(connector=connector) self._sessions.append(session) return OctoprintClient( host=user_input[CONF_HOST], session=session, port=user_input[CONF_PORT], ssl=user_input[CONF_SSL], path=user_input[CONF_PATH], ) def async_remove(self) -> None: """Detach the session.""" for session in self._sessions: session.detach() class CannotConnect(HomeAssistantError): """Error to indicate we cannot connect."""