diff --git a/homeassistant/components/octoprint/__init__.py b/homeassistant/components/octoprint/__init__.py index 50ba6c964f3..1a96078c003 100644 --- a/homeassistant/components/octoprint/__init__.py +++ b/homeassistant/components/octoprint/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +from typing import cast import aiohttp from pyoctoprintapi import OctoprintClient @@ -11,24 +12,28 @@ from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry from homeassistant.const import ( CONF_API_KEY, CONF_BINARY_SENSORS, + CONF_DEVICE_ID, CONF_HOST, CONF_MONITORED_CONDITIONS, CONF_NAME, CONF_PATH, CONF_PORT, + CONF_PROFILE_NAME, CONF_SENSORS, CONF_SSL, CONF_VERIFY_SSL, EVENT_HOMEASSISTANT_STOP, Platform, ) -from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.core import Event, HomeAssistant, ServiceCall, callback +from homeassistant.exceptions import ServiceValidationError import homeassistant.helpers.config_validation as cv +import homeassistant.helpers.device_registry as dr from homeassistant.helpers.typing import ConfigType from homeassistant.util import slugify as util_slugify from homeassistant.util.ssl import get_default_context, get_default_no_verify_context -from .const import DOMAIN +from .const import CONF_BAUDRATE, DOMAIN, SERVICE_CONNECT from .coordinator import OctoprintDataUpdateCoordinator _LOGGER = logging.getLogger(__name__) @@ -122,6 +127,15 @@ CONFIG_SCHEMA = vol.Schema( extra=vol.ALLOW_EXTRA, ) +SERVICE_CONNECT_SCHEMA = vol.Schema( + { + vol.Required(CONF_DEVICE_ID): cv.string, + vol.Optional(CONF_PROFILE_NAME): cv.string, + vol.Optional(CONF_PORT): cv.string, + vol.Optional(CONF_BAUDRATE): cv.positive_int, + } +) + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the OctoPrint component.""" @@ -194,6 +208,23 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + async def async_printer_connect(call: ServiceCall) -> None: + """Connect to a printer.""" + client = async_get_client_for_service_call(hass, call) + await client.connect( + printer_profile=call.data.get(CONF_PROFILE_NAME), + port=call.data.get(CONF_PORT), + baud_rate=call.data.get(CONF_BAUDRATE), + ) + + if not hass.services.has_service(DOMAIN, SERVICE_CONNECT): + hass.services.async_register( + DOMAIN, + SERVICE_CONNECT, + async_printer_connect, + schema=SERVICE_CONNECT_SCHEMA, + ) + return True @@ -205,3 +236,24 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: hass.data[DOMAIN].pop(entry.entry_id) return unload_ok + + +def async_get_client_for_service_call( + hass: HomeAssistant, call: ServiceCall +) -> OctoprintClient: + """Get the client related to a service call (by device ID).""" + device_id = call.data[CONF_DEVICE_ID] + device_registry = dr.async_get(hass) + + if device_entry := device_registry.async_get(device_id): + for entry_id in device_entry.config_entries: + if data := hass.data[DOMAIN].get(entry_id): + return cast(OctoprintClient, data["client"]) + + raise ServiceValidationError( + translation_domain=DOMAIN, + translation_key="missing_client", + translation_placeholders={ + "device_id": device_id, + }, + ) diff --git a/homeassistant/components/octoprint/const.py b/homeassistant/components/octoprint/const.py index df22cb8d8f8..2d2a9e4a907 100644 --- a/homeassistant/components/octoprint/const.py +++ b/homeassistant/components/octoprint/const.py @@ -3,3 +3,6 @@ DOMAIN = "octoprint" DEFAULT_NAME = "OctoPrint" + +SERVICE_CONNECT = "printer_connect" +CONF_BAUDRATE = "baudrate" diff --git a/homeassistant/components/octoprint/services.yaml b/homeassistant/components/octoprint/services.yaml new file mode 100644 index 00000000000..2cb4a6f3c2d --- /dev/null +++ b/homeassistant/components/octoprint/services.yaml @@ -0,0 +1,27 @@ +printer_connect: + fields: + device_id: + required: true + selector: + device: + integration: octoprint + profile_name: + required: false + selector: + text: + port: + required: false + selector: + text: + baudrate: + required: false + selector: + select: + options: + - "9600" + - "19200" + - "38400" + - "57600" + - "115200" + - "230400" + - "250000" diff --git a/homeassistant/components/octoprint/strings.json b/homeassistant/components/octoprint/strings.json index 63d9753ee1d..e9df0ed755c 100644 --- a/homeassistant/components/octoprint/strings.json +++ b/homeassistant/components/octoprint/strings.json @@ -35,5 +35,34 @@ "progress": { "get_api_key": "Open the OctoPrint UI and click 'Allow' on the Access Request for 'Home Assistant'." } + }, + "exceptions": { + "missing_client": { + "message": "No client for device ID: {device_id}" + } + }, + "services": { + "printer_connect": { + "name": "Connect to a printer", + "description": "Instructs the octoprint server to connect to a printer.", + "fields": { + "device_id": { + "name": "Server", + "description": "The server that should connect." + }, + "profile_name": { + "name": "Profile name", + "description": "Printer profile to connect with." + }, + "port": { + "name": "Serial port", + "description": "Port name to connect on." + }, + "baudrate": { + "name": "Baudrate", + "description": "Baud rate." + } + } + } } } diff --git a/tests/components/octoprint/test_servics.py b/tests/components/octoprint/test_servics.py new file mode 100644 index 00000000000..70e983c4bb4 --- /dev/null +++ b/tests/components/octoprint/test_servics.py @@ -0,0 +1,66 @@ +"""Test the OctoPrint services.""" +from unittest.mock import patch + +from homeassistant.components.octoprint.const import ( + CONF_BAUDRATE, + DOMAIN, + SERVICE_CONNECT, +) +from homeassistant.const import ATTR_DEVICE_ID, CONF_PORT, CONF_PROFILE_NAME +from homeassistant.helpers.device_registry import ( + async_entries_for_config_entry, + async_get as async_get_dev_reg, +) + +from . import init_integration + + +async def test_connect_default(hass) -> None: + """Test the connect to printer service.""" + await init_integration(hass, "sensor") + + dev_reg = async_get_dev_reg(hass) + device = async_entries_for_config_entry(dev_reg, "uuid")[0] + + # Test pausing the printer when it is printing + with patch("pyoctoprintapi.OctoprintClient.connect") as connect_command: + await hass.services.async_call( + DOMAIN, + SERVICE_CONNECT, + { + ATTR_DEVICE_ID: device.id, + }, + blocking=True, + ) + + assert len(connect_command.mock_calls) == 1 + connect_command.assert_called_with( + port=None, printer_profile=None, baud_rate=None + ) + + +async def test_connect_all_arguments(hass) -> None: + """Test the connect to printer service.""" + await init_integration(hass, "sensor") + + dev_reg = async_get_dev_reg(hass) + device = async_entries_for_config_entry(dev_reg, "uuid")[0] + + # Test pausing the printer when it is printing + with patch("pyoctoprintapi.OctoprintClient.connect") as connect_command: + await hass.services.async_call( + DOMAIN, + SERVICE_CONNECT, + { + ATTR_DEVICE_ID: device.id, + CONF_PROFILE_NAME: "Test Profile", + CONF_PORT: "VIRTUAL", + CONF_BAUDRATE: 9600, + }, + blocking=True, + ) + + assert len(connect_command.mock_calls) == 1 + connect_command.assert_called_with( + port="VIRTUAL", printer_profile="Test Profile", baud_rate=9600 + )