Device Automation: enforce passing in device-automation-enum (#69013)

pull/69042/head
Paulus Schoutsen 2022-03-31 14:30:11 -07:00 committed by GitHub
parent 69ee4cd978
commit 824066f519
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 42 deletions

View File

@ -20,7 +20,6 @@ from homeassistant.helpers import (
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
) )
from homeassistant.helpers.frame import report
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import IntegrationNotFound, bind_hass from homeassistant.loader import IntegrationNotFound, bind_hass
from homeassistant.requirements import async_get_integration_with_requirements from homeassistant.requirements import async_get_integration_with_requirements
@ -88,24 +87,6 @@ TYPES = {
} }
@bind_hass
async def async_get_device_automations(
hass: HomeAssistant,
automation_type: DeviceAutomationType | str,
device_ids: Iterable[str] | None = None,
) -> Mapping[str, Any]:
"""Return all the device automations for a type optionally limited to specific device ids."""
if isinstance(automation_type, str):
report(
"uses str for async_get_device_automations automation_type. This is "
"deprecated and will stop working in Home Assistant 2022.4, it should be "
"updated to use DeviceAutomationType instead",
error_if_core=False,
)
automation_type = DeviceAutomationType[automation_type.upper()]
return await _async_get_device_automations(hass, automation_type, device_ids)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up device automation.""" """Set up device automation."""
websocket_api.async_register_command(hass, websocket_device_automation_list_actions) websocket_api.async_register_command(hass, websocket_device_automation_list_actions)
@ -156,26 +137,18 @@ async def async_get_device_automation_platform( # noqa: D103
@overload @overload
async def async_get_device_automation_platform( # noqa: D103 async def async_get_device_automation_platform( # noqa: D103
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType
) -> "DeviceAutomationPlatformType": ) -> "DeviceAutomationPlatformType":
... ...
async def async_get_device_automation_platform( async def async_get_device_automation_platform(
hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType | str hass: HomeAssistant, domain: str, automation_type: DeviceAutomationType
) -> "DeviceAutomationPlatformType": ) -> "DeviceAutomationPlatformType":
"""Load device automation platform for integration. """Load device automation platform for integration.
Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation. Throws InvalidDeviceAutomationConfig if the integration is not found or does not support device automation.
""" """
if isinstance(automation_type, str):
report(
"uses str for async_get_device_automation_platform automation_type. This "
"is deprecated and will stop working in Home Assistant 2022.4, it should "
"be updated to use DeviceAutomationType instead",
error_if_core=False,
)
automation_type = DeviceAutomationType[automation_type.upper()]
platform_name = automation_type.value.section platform_name = automation_type.value.section
try: try:
integration = await async_get_integration_with_requirements(hass, domain) integration = await async_get_integration_with_requirements(hass, domain)
@ -215,10 +188,11 @@ async def _async_get_device_automations_from_domain(
) )
async def _async_get_device_automations( @bind_hass
async def async_get_device_automations(
hass: HomeAssistant, hass: HomeAssistant,
automation_type: DeviceAutomationType, automation_type: DeviceAutomationType,
device_ids: Iterable[str] | None, device_ids: Iterable[str] | None = None,
) -> Mapping[str, list[dict[str, Any]]]: ) -> Mapping[str, list[dict[str, Any]]]:
"""List device automations.""" """List device automations."""
device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)
@ -336,7 +310,7 @@ async def websocket_device_automation_list_actions(hass, connection, msg):
"""Handle request for device actions.""" """Handle request for device actions."""
device_id = msg["device_id"] device_id = msg["device_id"]
actions = ( actions = (
await _async_get_device_automations( await async_get_device_automations(
hass, DeviceAutomationType.ACTION, [device_id] hass, DeviceAutomationType.ACTION, [device_id]
) )
).get(device_id) ).get(device_id)
@ -355,7 +329,7 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
"""Handle request for device conditions.""" """Handle request for device conditions."""
device_id = msg["device_id"] device_id = msg["device_id"]
conditions = ( conditions = (
await _async_get_device_automations( await async_get_device_automations(
hass, DeviceAutomationType.CONDITION, [device_id] hass, DeviceAutomationType.CONDITION, [device_id]
) )
).get(device_id) ).get(device_id)
@ -374,7 +348,7 @@ async def websocket_device_automation_list_triggers(hass, connection, msg):
"""Handle request for device triggers.""" """Handle request for device triggers."""
device_id = msg["device_id"] device_id = msg["device_id"]
triggers = ( triggers = (
await _async_get_device_automations( await async_get_device_automations(
hass, DeviceAutomationType.TRIGGER, [device_id] hass, DeviceAutomationType.TRIGGER, [device_id]
) )
).get(device_id) ).get(device_id)

View File

@ -404,13 +404,6 @@ async def test_async_get_device_automations_single_device_trigger(
assert device_entry.id in result assert device_entry.id in result
assert len(result[device_entry.id]) == 3 assert len(result[device_entry.id]) == 3
# Test deprecated str automation_type works, to be removed in 2022.4
result = await device_automation.async_get_device_automations(
hass, "trigger", [device_entry.id]
)
assert device_entry.id in result
assert len(result[device_entry.id]) == 3 # toggled, turned_on, turned_off
async def test_async_get_device_automations_all_devices_trigger( async def test_async_get_device_automations_all_devices_trigger(
hass, device_reg, entity_reg hass, device_reg, entity_reg

View File

@ -2,6 +2,7 @@
import pytest import pytest
from homeassistant.components import automation from homeassistant.components import automation
from homeassistant.components.device_automation import DeviceAutomationType
from homeassistant.components.device_automation.exceptions import ( from homeassistant.components.device_automation.exceptions import (
InvalidDeviceAutomationConfig, InvalidDeviceAutomationConfig,
) )
@ -31,7 +32,9 @@ async def test_get_triggers(hass, client):
"device_id": device.id, "device_id": device.id,
} }
triggers = await async_get_device_automations(hass, "trigger", device.id) triggers = await async_get_device_automations(
hass, DeviceAutomationType.TRIGGER, device.id
)
assert turn_on_trigger in triggers assert turn_on_trigger in triggers