Refactor Rest Switch with ManualTriggerEntity (#97403)

* Refactor Rest Switch with ManualTriggerEntity

* Fix test

* Fix 2

* review comments

* remove async_added_to_hass

* update on startup
pull/98398/head^2
G Johansson 2023-08-15 11:43:47 +02:00 committed by GitHub
parent 87b7fc6c61
commit ed18c6a013
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 20 deletions

View File

@ -18,7 +18,9 @@ from homeassistant.components.switch import (
from homeassistant.const import (
CONF_DEVICE_CLASS,
CONF_HEADERS,
CONF_ICON,
CONF_METHOD,
CONF_NAME,
CONF_PARAMS,
CONF_PASSWORD,
CONF_RESOURCE,
@ -33,8 +35,10 @@ from homeassistant.helpers import config_validation as cv, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.helpers.template_entity import (
CONF_AVAILABILITY,
CONF_PICTURE,
TEMPLATE_ENTITY_BASE_SCHEMA,
TemplateEntity,
ManualTriggerEntity,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
@ -44,6 +48,14 @@ CONF_BODY_ON = "body_on"
CONF_IS_ON_TEMPLATE = "is_on_template"
CONF_STATE_RESOURCE = "state_resource"
TRIGGER_ENTITY_OPTIONS = (
CONF_AVAILABILITY,
CONF_DEVICE_CLASS,
CONF_ICON,
CONF_PICTURE,
CONF_UNIQUE_ID,
)
DEFAULT_METHOD = "post"
DEFAULT_BODY_OFF = "OFF"
DEFAULT_BODY_ON = "ON"
@ -71,6 +83,7 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Inclusive(CONF_USERNAME, "authentication"): cv.string,
vol.Inclusive(CONF_PASSWORD, "authentication"): cv.string,
vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean,
vol.Optional(CONF_AVAILABILITY): cv.template,
}
)
@ -83,10 +96,17 @@ async def async_setup_platform(
) -> None:
"""Set up the RESTful switch."""
resource: str = config[CONF_RESOURCE]
unique_id: str | None = config.get(CONF_UNIQUE_ID)
name = config.get(CONF_NAME) or template.Template(DEFAULT_NAME, hass)
trigger_entity_config = {CONF_NAME: name}
for key in TRIGGER_ENTITY_OPTIONS:
if key not in config:
continue
trigger_entity_config[key] = config[key]
try:
switch = RestSwitch(hass, config, unique_id)
switch = RestSwitch(hass, config, trigger_entity_config)
req = await switch.get_device_state(hass)
if req.status_code >= HTTPStatus.BAD_REQUEST:
@ -102,23 +122,17 @@ async def async_setup_platform(
raise PlatformNotReady(f"No route to resource/endpoint: {resource}") from exc
class RestSwitch(TemplateEntity, SwitchEntity):
class RestSwitch(ManualTriggerEntity, SwitchEntity):
"""Representation of a switch that can be toggled using REST."""
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
unique_id: str | None,
trigger_entity_config: ConfigType,
) -> None:
"""Initialize the REST switch."""
TemplateEntity.__init__(
self,
hass,
config=config,
fallback_name=DEFAULT_NAME,
unique_id=unique_id,
)
ManualTriggerEntity.__init__(self, hass, trigger_entity_config)
auth: httpx.BasicAuth | None = None
username: str | None = None
@ -138,8 +152,6 @@ class RestSwitch(TemplateEntity, SwitchEntity):
self._timeout: int = config[CONF_TIMEOUT]
self._verify_ssl: bool = config[CONF_VERIFY_SSL]
self._attr_device_class = config.get(CONF_DEVICE_CLASS)
self._body_on.hass = hass
self._body_off.hass = hass
if (is_on_template := self._is_on_template) is not None:
@ -148,6 +160,11 @@ class RestSwitch(TemplateEntity, SwitchEntity):
template.attach(hass, self._headers)
template.attach(hass, self._params)
async def async_added_to_hass(self) -> None:
"""Handle adding to Home Assistant."""
await super().async_added_to_hass()
await self.async_update()
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn the device on."""
body_on_t = self._body_on.async_render(parse_result=False)
@ -198,13 +215,18 @@ class RestSwitch(TemplateEntity, SwitchEntity):
async def async_update(self) -> None:
"""Get the current state, catching errors."""
req = None
try:
await self.get_device_state(self.hass)
req = await self.get_device_state(self.hass)
except asyncio.TimeoutError:
_LOGGER.exception("Timed out while fetching data")
except httpx.RequestError as err:
_LOGGER.exception("Error while fetching data: %s", err)
if req:
self._process_manual_data(req.text)
self.async_write_ha_state()
async def get_device_state(self, hass: HomeAssistant) -> httpx.Response:
"""Get the latest data from REST API and update the state."""
websession = get_async_client(hass, self._verify_ssl)

View File

@ -111,7 +111,7 @@ async def test_setup_minimum(hass: HomeAssistant) -> None:
with assert_setup_component(1, SWITCH_DOMAIN):
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
await hass.async_block_till_done()
assert route.call_count == 1
assert route.call_count == 2
@respx.mock
@ -129,7 +129,7 @@ async def test_setup_query_params(hass: HomeAssistant) -> None:
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
await hass.async_block_till_done()
assert route.call_count == 1
assert route.call_count == 2
@respx.mock
@ -148,7 +148,7 @@ async def test_setup(hass: HomeAssistant) -> None:
}
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
await hass.async_block_till_done()
assert route.call_count == 1
assert route.call_count == 2
assert_setup_component(1, SWITCH_DOMAIN)
@ -170,7 +170,7 @@ async def test_setup_with_state_resource(hass: HomeAssistant) -> None:
}
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
await hass.async_block_till_done()
assert route.call_count == 1
assert route.call_count == 2
assert_setup_component(1, SWITCH_DOMAIN)
@ -195,7 +195,7 @@ async def test_setup_with_templated_headers_params(hass: HomeAssistant) -> None:
}
assert await async_setup_component(hass, SWITCH_DOMAIN, config)
await hass.async_block_till_done()
assert route.call_count == 1
assert route.call_count == 2
last_call = route.calls[-1]
last_request: httpx.Request = last_call.request
assert last_request.headers.get("Accept") == CONTENT_TYPE_JSON