Improve MQTT type hints part 1 (#80523)

* Improve typing alarm_control_panel

* Improve typing binary_sensor

* Improve typing button

* Add misssed annotation

* Move CONF_EXPIRE_AFTER to _setup_from_config

* Use CALL_BACK type

* Remove assert, improve code style
pull/76999/head
Jan Bouwhuis 2022-11-02 20:33:18 +01:00 committed by GitHub
parent 76819d81be
commit b4ad03784f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 61 deletions

View File

@ -48,7 +48,7 @@ from .mixins import (
async_setup_platform_helper, async_setup_platform_helper,
warn_for_legacy_schema, warn_for_legacy_schema,
) )
from .models import MqttCommandTemplate, MqttValueTemplate from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -155,8 +155,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the MQTT Alarm Control Panel platform.""" """Set up the MQTT Alarm Control Panel platform."""
async_add_entities([MqttAlarm(hass, config, config_entry, discovery_data)]) async_add_entities([MqttAlarm(hass, config, config_entry, discovery_data)])
@ -168,32 +168,39 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
_entity_id_format = alarm.ENTITY_ID_FORMAT _entity_id_format = alarm.ENTITY_ID_FORMAT
_attributes_extra_blocked = MQTT_ALARM_ATTRIBUTES_BLOCKED _attributes_extra_blocked = MQTT_ALARM_ATTRIBUTES_BLOCKED
def __init__(self, hass, config, config_entry, discovery_data): def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Init the MQTT Alarm Control Panel.""" """Init the MQTT Alarm Control Panel."""
self._state: str | None = None self._state: str | None = None
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _setup_from_config(self, config): def _setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup the entity."""
self._value_template = MqttValueTemplate( self._value_template = MqttValueTemplate(
self._config.get(CONF_VALUE_TEMPLATE), config.get(CONF_VALUE_TEMPLATE),
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
self._command_template = MqttCommandTemplate( self._command_template = MqttCommandTemplate(
self._config[CONF_COMMAND_TEMPLATE], entity=self config[CONF_COMMAND_TEMPLATE], entity=self
).async_render ).async_render
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def message_received(msg): def message_received(msg: ReceiveMessage) -> None:
"""Run when new MQTT message has been received.""" """Run when new MQTT message has been received."""
payload = self._value_template(msg.payload) payload = self._value_template(msg.payload)
if payload not in ( if payload not in (
@ -210,7 +217,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
): ):
_LOGGER.warning("Received unexpected payload: %s", msg.payload) _LOGGER.warning("Received unexpected payload: %s", msg.payload)
return return
self._state = payload self._state = str(payload)
get_mqtt_data(self.hass).state_write_requests.write_state_request(self) get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
@ -226,7 +233,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
}, },
) )
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)
@ -250,6 +257,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
@property @property
def code_format(self) -> alarm.CodeFormat | None: def code_format(self) -> alarm.CodeFormat | None:
"""Return one or more digits/characters.""" """Return one or more digits/characters."""
code: str | None
if (code := self._config.get(CONF_CODE)) is None: if (code := self._config.get(CONF_CODE)) is None:
return None return None
if code == REMOTE_CODE or (isinstance(code, str) and re.search("^\\d+$", code)): if code == REMOTE_CODE or (isinstance(code, str) and re.search("^\\d+$", code)):
@ -266,10 +274,10 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
This method is a coroutine. This method is a coroutine.
""" """
code_required = self._config[CONF_CODE_DISARM_REQUIRED] code_required: bool = self._config[CONF_CODE_DISARM_REQUIRED]
if code_required and not self._validate_code(code, "disarming"): if code_required and not self._validate_code(code, "disarming"):
return return
payload = self._config[CONF_PAYLOAD_DISARM] payload: str = self._config[CONF_PAYLOAD_DISARM]
await self._publish(code, payload) await self._publish(code, payload)
async def async_alarm_arm_home(self, code: str | None = None) -> None: async def async_alarm_arm_home(self, code: str | None = None) -> None:
@ -277,10 +285,10 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
This method is a coroutine. This method is a coroutine.
""" """
code_required = self._config[CONF_CODE_ARM_REQUIRED] code_required: bool = self._config[CONF_CODE_ARM_REQUIRED]
if code_required and not self._validate_code(code, "arming home"): if code_required and not self._validate_code(code, "arming home"):
return return
action = self._config[CONF_PAYLOAD_ARM_HOME] action: str = self._config[CONF_PAYLOAD_ARM_HOME]
await self._publish(code, action) await self._publish(code, action)
async def async_alarm_arm_away(self, code: str | None = None) -> None: async def async_alarm_arm_away(self, code: str | None = None) -> None:
@ -288,10 +296,10 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
This method is a coroutine. This method is a coroutine.
""" """
code_required = self._config[CONF_CODE_ARM_REQUIRED] code_required: bool = self._config[CONF_CODE_ARM_REQUIRED]
if code_required and not self._validate_code(code, "arming away"): if code_required and not self._validate_code(code, "arming away"):
return return
action = self._config[CONF_PAYLOAD_ARM_AWAY] action: str = self._config[CONF_PAYLOAD_ARM_AWAY]
await self._publish(code, action) await self._publish(code, action)
async def async_alarm_arm_night(self, code: str | None = None) -> None: async def async_alarm_arm_night(self, code: str | None = None) -> None:
@ -299,10 +307,10 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
This method is a coroutine. This method is a coroutine.
""" """
code_required = self._config[CONF_CODE_ARM_REQUIRED] code_required: bool = self._config[CONF_CODE_ARM_REQUIRED]
if code_required and not self._validate_code(code, "arming night"): if code_required and not self._validate_code(code, "arming night"):
return return
action = self._config[CONF_PAYLOAD_ARM_NIGHT] action: str = self._config[CONF_PAYLOAD_ARM_NIGHT]
await self._publish(code, action) await self._publish(code, action)
async def async_alarm_arm_vacation(self, code: str | None = None) -> None: async def async_alarm_arm_vacation(self, code: str | None = None) -> None:
@ -310,10 +318,10 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
This method is a coroutine. This method is a coroutine.
""" """
code_required = self._config[CONF_CODE_ARM_REQUIRED] code_required: bool = self._config[CONF_CODE_ARM_REQUIRED]
if code_required and not self._validate_code(code, "arming vacation"): if code_required and not self._validate_code(code, "arming vacation"):
return return
action = self._config[CONF_PAYLOAD_ARM_VACATION] action: str = self._config[CONF_PAYLOAD_ARM_VACATION]
await self._publish(code, action) await self._publish(code, action)
async def async_alarm_arm_custom_bypass(self, code: str | None = None) -> None: async def async_alarm_arm_custom_bypass(self, code: str | None = None) -> None:
@ -321,10 +329,10 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
This method is a coroutine. This method is a coroutine.
""" """
code_required = self._config[CONF_CODE_ARM_REQUIRED] code_required: bool = self._config[CONF_CODE_ARM_REQUIRED]
if code_required and not self._validate_code(code, "arming custom bypass"): if code_required and not self._validate_code(code, "arming custom bypass"):
return return
action = self._config[CONF_PAYLOAD_ARM_CUSTOM_BYPASS] action: str = self._config[CONF_PAYLOAD_ARM_CUSTOM_BYPASS]
await self._publish(code, action) await self._publish(code, action)
async def async_alarm_trigger(self, code: str | None = None) -> None: async def async_alarm_trigger(self, code: str | None = None) -> None:
@ -332,13 +340,13 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
This method is a coroutine. This method is a coroutine.
""" """
code_required = self._config[CONF_CODE_TRIGGER_REQUIRED] code_required: bool = self._config[CONF_CODE_TRIGGER_REQUIRED]
if code_required and not self._validate_code(code, "triggering"): if code_required and not self._validate_code(code, "triggering"):
return return
action = self._config[CONF_PAYLOAD_TRIGGER] action: str = self._config[CONF_PAYLOAD_TRIGGER]
await self._publish(code, action) await self._publish(code, action)
async def _publish(self, code, action): async def _publish(self, code: str | None, action: str) -> None:
"""Publish via mqtt.""" """Publish via mqtt."""
variables = {"action": action, "code": code} variables = {"action": action, "code": code}
payload = self._command_template(None, variables=variables) payload = self._command_template(None, variables=variables)
@ -350,10 +358,10 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
self._config[CONF_ENCODING], self._config[CONF_ENCODING],
) )
def _validate_code(self, code, state): def _validate_code(self, code: str | None, state: str) -> bool:
"""Validate given code.""" """Validate given code."""
conf_code = self._config.get(CONF_CODE) conf_code: str | None = self._config.get(CONF_CODE)
check = ( check = bool(
conf_code is None conf_code is None
or code == conf_code or code == conf_code
or (conf_code == REMOTE_CODE and code) or (conf_code == REMOTE_CODE and code)

View File

@ -1,9 +1,10 @@
"""Support for MQTT binary sensors.""" """Support for MQTT binary sensors."""
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from datetime import datetime, timedelta
import functools import functools
import logging import logging
from typing import Any
import voluptuous as vol import voluptuous as vol
@ -24,7 +25,7 @@ from homeassistant.const import (
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
STATE_UNKNOWN, STATE_UNKNOWN,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
import homeassistant.helpers.event as evt import homeassistant.helpers.event as evt
@ -45,7 +46,7 @@ from .mixins import (
async_setup_platform_helper, async_setup_platform_helper,
warn_for_legacy_schema, warn_for_legacy_schema,
) )
from .models import MqttValueTemplate from .models import MqttValueTemplate, ReceiveMessage
from .util import get_mqtt_data from .util import get_mqtt_data
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -111,8 +112,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the MQTT binary sensor.""" """Set up the MQTT binary sensor."""
async_add_entities([MqttBinarySensor(hass, config, config_entry, discovery_data)]) async_add_entities([MqttBinarySensor(hass, config, config_entry, discovery_data)])
@ -122,16 +123,18 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
"""Representation a binary sensor that is updated by MQTT.""" """Representation a binary sensor that is updated by MQTT."""
_entity_id_format = binary_sensor.ENTITY_ID_FORMAT _entity_id_format = binary_sensor.ENTITY_ID_FORMAT
_expired: bool | None
def __init__(self, hass, config, config_entry, discovery_data): def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the MQTT binary sensor.""" """Initialize the MQTT binary sensor."""
self._expiration_trigger = None self._expiration_trigger: CALLBACK_TYPE | None = None
self._delay_listener = None self._delay_listener: CALLBACK_TYPE | None = None
expire_after = config.get(CONF_EXPIRE_AFTER)
if expire_after is not None and expire_after > 0:
self._expired = True
else:
self._expired = None
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@ -146,7 +149,9 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
# MqttEntity.async_added_to_hass(), then we should not restore state # MqttEntity.async_added_to_hass(), then we should not restore state
and not self._expiration_trigger and not self._expiration_trigger
): ):
expiration_at = last_state.last_changed + timedelta(seconds=expire_after) expiration_at: datetime = last_state.last_changed + timedelta(
seconds=expire_after
)
if expiration_at < (time_now := dt_util.utcnow()): if expiration_at < (time_now := dt_util.utcnow()):
# Skip reactivating the binary_sensor # Skip reactivating the binary_sensor
_LOGGER.debug("Skip state recovery after reload for %s", self.entity_id) _LOGGER.debug("Skip state recovery after reload for %s", self.entity_id)
@ -174,24 +179,30 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
await MqttEntity.async_will_remove_from_hass(self) await MqttEntity.async_will_remove_from_hass(self)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _setup_from_config(self, config: ConfigType) -> None: def _setup_from_config(self, config: ConfigType) -> None:
self._attr_device_class = config.get(CONF_DEVICE_CLASS) """(Re)Setup the entity."""
expire_after: int | None = config.get(CONF_EXPIRE_AFTER)
if expire_after is not None and expire_after > 0:
self._expired = True
else:
self._expired = None
self._attr_force_update = config[CONF_FORCE_UPDATE] self._attr_force_update = config[CONF_FORCE_UPDATE]
self._attr_device_class = config.get(CONF_DEVICE_CLASS)
self._value_template = MqttValueTemplate( self._value_template = MqttValueTemplate(
self._config.get(CONF_VALUE_TEMPLATE), self._config.get(CONF_VALUE_TEMPLATE),
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
def off_delay_listener(now): def off_delay_listener(now: datetime) -> None:
"""Switch device off after a delay.""" """Switch device off after a delay."""
self._delay_listener = None self._delay_listener = None
self._attr_is_on = False self._attr_is_on = False
@ -199,10 +210,10 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def state_message_received(msg): def state_message_received(msg: ReceiveMessage) -> None:
"""Handle a new received MQTT state message.""" """Handle a new received MQTT state message."""
# auto-expire enabled? # auto-expire enabled?
expire_after = self._config.get(CONF_EXPIRE_AFTER) expire_after: int | None = self._config.get(CONF_EXPIRE_AFTER)
if expire_after is not None and expire_after > 0: if expire_after is not None and expire_after > 0:
@ -241,7 +252,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
else: # Payload is not for this entity else: # Payload is not for this entity
template_info = "" template_info = ""
if self._config.get(CONF_VALUE_TEMPLATE) is not None: if self._config.get(CONF_VALUE_TEMPLATE) is not None:
template_info = f", template output: '{payload}', with value template '{str(self._config.get(CONF_VALUE_TEMPLATE))}'" template_info = f", template output: '{str(payload)}', with value template '{str(self._config.get(CONF_VALUE_TEMPLATE))}'"
_LOGGER.info( _LOGGER.info(
"No matching payload found for entity: %s with state topic: %s. Payload: '%s'%s", "No matching payload found for entity: %s with state topic: %s. Payload: '%s'%s",
self._config[CONF_NAME], self._config[CONF_NAME],
@ -276,12 +287,12 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
}, },
) )
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)
@callback @callback
def _value_is_expired(self, *_): def _value_is_expired(self, *_: Any) -> None:
"""Triggered when value is expired.""" """Triggered when value is expired."""
self._expiration_trigger = None self._expiration_trigger = None
self._expired = True self._expired = True
@ -291,7 +302,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
@property @property
def available(self) -> bool: def available(self) -> bool:
"""Return true if the device is available and value has not expired.""" """Return true if the device is available and value has not expired."""
expire_after = self._config.get(CONF_EXPIRE_AFTER) expire_after: int | None = self._config.get(CONF_EXPIRE_AFTER)
# mypy doesn't know about fget: https://github.com/python/mypy/issues/6185 # mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined] return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined]
expire_after is None or not self._expired expire_after is None or not self._expired

View File

@ -90,8 +90,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the MQTT button.""" """Set up the MQTT button."""
async_add_entities([MqttButton(hass, config, config_entry, discovery_data)]) async_add_entities([MqttButton(hass, config, config_entry, discovery_data)])
@ -102,25 +102,31 @@ class MqttButton(MqttEntity, ButtonEntity):
_entity_id_format = button.ENTITY_ID_FORMAT _entity_id_format = button.ENTITY_ID_FORMAT
def __init__(self, hass, config, config_entry, discovery_data): def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the MQTT button.""" """Initialize the MQTT button."""
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _setup_from_config(self, config): def _setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup the entity.""" """(Re)Setup the entity."""
self._command_template = MqttCommandTemplate( self._command_template = MqttCommandTemplate(
config.get(CONF_COMMAND_TEMPLATE), entity=self config.get(CONF_COMMAND_TEMPLATE), entity=self
).async_render ).async_render
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@property @property