diff --git a/homeassistant/components/easyenergy/__init__.py b/homeassistant/components/easyenergy/__init__.py index 6c00ec5a6a3..e941c78b1fb 100644 --- a/homeassistant/components/easyenergy/__init__.py +++ b/homeassistant/components/easyenergy/__init__.py @@ -5,12 +5,23 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.typing import ConfigType from .const import DOMAIN from .coordinator import EasyEnergyDataUpdateCoordinator from .services import async_setup_services PLATFORMS = [Platform.SENSOR] +CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Set up the easyEnergy services.""" + + async_setup_services(hass) + + return True async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: @@ -27,8 +38,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) - async_setup_services(hass, coordinator) - return True diff --git a/homeassistant/components/easyenergy/services.py b/homeassistant/components/easyenergy/services.py index 777fa4280b2..a68dfcb791c 100644 --- a/homeassistant/components/easyenergy/services.py +++ b/homeassistant/components/easyenergy/services.py @@ -9,6 +9,7 @@ from typing import Final from easyenergy import Electricity, Gas, VatOption import voluptuous as vol +from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.core import ( HomeAssistant, ServiceCall, @@ -17,11 +18,13 @@ from homeassistant.core import ( callback, ) from homeassistant.exceptions import ServiceValidationError +from homeassistant.helpers import selector from homeassistant.util import dt as dt_util from .const import DOMAIN from .coordinator import EasyEnergyDataUpdateCoordinator +ATTR_CONFIG_ENTRY: Final = "config_entry" ATTR_START: Final = "start" ATTR_END: Final = "end" ATTR_INCL_VAT: Final = "incl_vat" @@ -31,6 +34,11 @@ ENERGY_USAGE_SERVICE_NAME: Final = "get_energy_usage_prices" ENERGY_RETURN_SERVICE_NAME: Final = "get_energy_return_prices" SERVICE_SCHEMA: Final = vol.Schema( { + vol.Required(ATTR_CONFIG_ENTRY): selector.ConfigEntrySelector( + { + "integration": DOMAIN, + } + ), vol.Required(ATTR_INCL_VAT): bool, vol.Optional(ATTR_START): str, vol.Optional(ATTR_END): str, @@ -77,13 +85,44 @@ def __serialize_prices(prices: list[dict[str, float | datetime]]) -> ServiceResp } +def __get_coordinator( + hass: HomeAssistant, call: ServiceCall +) -> EasyEnergyDataUpdateCoordinator: + """Get the coordinator from the entry.""" + entry_id: str = call.data[ATTR_CONFIG_ENTRY] + entry: ConfigEntry | None = hass.config_entries.async_get_entry(entry_id) + + if not entry: + raise ServiceValidationError( + f"Invalid config entry: {entry_id}", + translation_domain=DOMAIN, + translation_key="invalid_config_entry", + translation_placeholders={ + "config_entry": entry_id, + }, + ) + if entry.state != ConfigEntryState.LOADED: + raise ServiceValidationError( + f"{entry.title} is not loaded", + translation_domain=DOMAIN, + translation_key="unloaded_config_entry", + translation_placeholders={ + "config_entry": entry.title, + }, + ) + + return hass.data[DOMAIN][entry_id] + + async def __get_prices( call: ServiceCall, *, - coordinator: EasyEnergyDataUpdateCoordinator, + hass: HomeAssistant, price_type: PriceType, ) -> ServiceResponse: """Get prices from easyEnergy.""" + coordinator = __get_coordinator(hass, call) + start = __get_date(call.data.get(ATTR_START)) end = __get_date(call.data.get(ATTR_END)) @@ -112,34 +151,27 @@ async def __get_prices( @callback -def async_setup_services( - hass: HomeAssistant, - coordinator: EasyEnergyDataUpdateCoordinator, -) -> None: +def async_setup_services(hass: HomeAssistant) -> None: """Set up services for easyEnergy integration.""" hass.services.async_register( DOMAIN, GAS_SERVICE_NAME, - partial(__get_prices, coordinator=coordinator, price_type=PriceType.GAS), + partial(__get_prices, hass=hass, price_type=PriceType.GAS), schema=SERVICE_SCHEMA, supports_response=SupportsResponse.ONLY, ) hass.services.async_register( DOMAIN, ENERGY_USAGE_SERVICE_NAME, - partial( - __get_prices, coordinator=coordinator, price_type=PriceType.ENERGY_USAGE - ), + partial(__get_prices, hass=hass, price_type=PriceType.ENERGY_USAGE), schema=SERVICE_SCHEMA, supports_response=SupportsResponse.ONLY, ) hass.services.async_register( DOMAIN, ENERGY_RETURN_SERVICE_NAME, - partial( - __get_prices, coordinator=coordinator, price_type=PriceType.ENERGY_RETURN - ), + partial(__get_prices, hass=hass, price_type=PriceType.ENERGY_RETURN), schema=SERVICE_SCHEMA, supports_response=SupportsResponse.ONLY, ) diff --git a/homeassistant/components/easyenergy/services.yaml b/homeassistant/components/easyenergy/services.yaml index 01b78431afb..63187256f00 100644 --- a/homeassistant/components/easyenergy/services.yaml +++ b/homeassistant/components/easyenergy/services.yaml @@ -1,5 +1,10 @@ get_gas_prices: fields: + config_entry: + required: true + selector: + config_entry: + integration: easyenergy incl_vat: required: true default: true @@ -17,6 +22,11 @@ get_gas_prices: datetime: get_energy_usage_prices: fields: + config_entry: + required: true + selector: + config_entry: + integration: easyenergy incl_vat: required: true default: true @@ -34,6 +44,11 @@ get_energy_usage_prices: datetime: get_energy_return_prices: fields: + config_entry: + required: true + selector: + config_entry: + integration: easyenergy start: required: false example: "2024-01-01 00:00:00" diff --git a/homeassistant/components/easyenergy/strings.json b/homeassistant/components/easyenergy/strings.json index 56d793818cb..c42ef9df5ac 100644 --- a/homeassistant/components/easyenergy/strings.json +++ b/homeassistant/components/easyenergy/strings.json @@ -12,6 +12,12 @@ "exceptions": { "invalid_date": { "message": "Invalid date provided. Got {date}" + }, + "invalid_config_entry": { + "message": "Invalid config entry provided. Got {config_entry}" + }, + "unloaded_config_entry": { + "message": "Invalid config entry provided. {config_entry} is not loaded." } }, "entity": { @@ -53,6 +59,10 @@ "name": "Get gas prices", "description": "Request gas prices from easyEnergy.", "fields": { + "config_entry": { + "name": "Config Entry", + "description": "The config entry to use for this service." + }, "incl_vat": { "name": "VAT Included", "description": "Include or exclude VAT in the prices, default is true." @@ -71,6 +81,10 @@ "name": "Get energy usage prices", "description": "Request usage energy prices from easyEnergy.", "fields": { + "config_entry": { + "name": "[%key:component::easyenergy::services::get_gas_prices::fields::config_entry::name%]", + "description": "[%key:component::easyenergy::services::get_gas_prices::fields::config_entry::description%]" + }, "incl_vat": { "name": "[%key:component::easyenergy::services::get_gas_prices::fields::incl_vat::name%]", "description": "[%key:component::easyenergy::services::get_gas_prices::fields::incl_vat::description%]" @@ -89,6 +103,10 @@ "name": "Get energy return prices", "description": "Request return energy prices from easyEnergy.", "fields": { + "config_entry": { + "name": "[%key:component::easyenergy::services::get_gas_prices::fields::config_entry::name%]", + "description": "[%key:component::easyenergy::services::get_gas_prices::fields::config_entry::description%]" + }, "start": { "name": "[%key:component::easyenergy::services::get_gas_prices::fields::start::name%]", "description": "[%key:component::easyenergy::services::get_gas_prices::fields::start::description%]" diff --git a/tests/components/easyenergy/test_services.py b/tests/components/easyenergy/test_services.py index d47b86e93a3..603768237f1 100644 --- a/tests/components/easyenergy/test_services.py +++ b/tests/components/easyenergy/test_services.py @@ -6,6 +6,7 @@ import voluptuous as vol from homeassistant.components.easyenergy.const import DOMAIN from homeassistant.components.easyenergy.services import ( + ATTR_CONFIG_ENTRY, ENERGY_RETURN_SERVICE_NAME, ENERGY_USAGE_SERVICE_NAME, GAS_SERVICE_NAME, @@ -13,6 +14,8 @@ from homeassistant.components.easyenergy.services import ( from homeassistant.core import HomeAssistant from homeassistant.exceptions import ServiceValidationError +from tests.common import MockConfigEntry + @pytest.mark.usefixtures("init_integration") async def test_has_services( @@ -38,6 +41,7 @@ async def test_has_services( @pytest.mark.parametrize("end", [{"end": "2023-01-01 00:00:00"}, {}]) async def test_service( hass: HomeAssistant, + mock_config_entry: MockConfigEntry, snapshot: SnapshotAssertion, service: str, incl_vat: dict[str, bool], @@ -45,8 +49,9 @@ async def test_service( end: dict[str, str], ) -> None: """Test the EnergyZero Service.""" + entry = {ATTR_CONFIG_ENTRY: mock_config_entry.entry_id} - data = incl_vat | start | end + data = entry | incl_vat | start | end assert snapshot == await hass.services.async_call( DOMAIN, @@ -57,6 +62,17 @@ async def test_service( ) +@pytest.fixture +def config_entry_data( + mock_config_entry: MockConfigEntry, request: pytest.FixtureRequest +) -> dict[str, str]: + """Fixture for the config entry.""" + if "config_entry" in request.param and request.param["config_entry"] is True: + return {"config_entry": mock_config_entry.entry_id} + + return request.param + + @pytest.mark.usefixtures("init_integration") @pytest.mark.parametrize( "service", @@ -67,29 +83,58 @@ async def test_service( ], ) @pytest.mark.parametrize( - ("service_data", "error", "error_message"), + ("config_entry_data", "service_data", "error", "error_message"), [ - ({}, vol.er.Error, "required key not provided .+"), + ({}, {}, vol.er.Error, "required key not provided .+"), ( + {"config_entry": True}, + {}, + vol.er.Error, + "required key not provided .+", + ), + ( + {}, + {"incl_vat": True}, + vol.er.Error, + "required key not provided .+", + ), + ( + {"config_entry": True}, {"incl_vat": "incorrect vat"}, vol.er.Error, "expected bool for dictionary value .+", ), ( - {"incl_vat": True, "start": "incorrect date"}, + {"config_entry": "incorrect entry"}, + {"incl_vat": True}, + ServiceValidationError, + "Invalid config entry.+", + ), + ( + {"config_entry": True}, + { + "incl_vat": True, + "start": "incorrect date", + }, ServiceValidationError, "Invalid datetime provided.", ), ( - {"incl_vat": True, "end": "incorrect date"}, + {"config_entry": True}, + { + "incl_vat": True, + "end": "incorrect date", + }, ServiceValidationError, "Invalid datetime provided.", ), ], + indirect=["config_entry_data"], ) async def test_service_validation( hass: HomeAssistant, service: str, + config_entry_data: dict[str, str], service_data: dict[str, str | bool], error: type[Exception], error_message: str, @@ -100,7 +145,7 @@ async def test_service_validation( await hass.services.async_call( DOMAIN, service, - service_data, + config_entry_data | service_data, blocking=True, return_response=True, )