From ff4456cb295f52d432db166d474d567299d98b39 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Mon, 7 Nov 2022 08:24:49 +0100 Subject: [PATCH] Improve MQTT type hints part 4 (#80971) * Improve typing humidifier * Improve typing lock * Improve typing number * Set humidifier type hints at class level * Set lock type hints at class level * Set number type hints at class level * Some small updates * Follow up comment * Remove assert --- homeassistant/components/mqtt/humidifier.py | 100 ++++++++++++-------- homeassistant/components/mqtt/lock.py | 58 ++++++------ homeassistant/components/mqtt/number.py | 66 ++++++++----- 3 files changed, 128 insertions(+), 96 deletions(-) diff --git a/homeassistant/components/mqtt/humidifier.py b/homeassistant/components/mqtt/humidifier.py index e3e94c07dae..69b6d3e3e89 100644 --- a/homeassistant/components/mqtt/humidifier.py +++ b/homeassistant/components/mqtt/humidifier.py @@ -1,6 +1,7 @@ """Support for MQTT humidifiers.""" from __future__ import annotations +from collections.abc import Callable import functools import logging from typing import Any @@ -28,6 +29,7 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.template import Template from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import subscription @@ -50,7 +52,13 @@ from .mixins import ( async_setup_platform_helper, warn_for_legacy_schema, ) -from .models import MqttCommandTemplate, MqttValueTemplate +from .models import ( + MqttCommandTemplate, + MqttValueTemplate, + PublishPayloadType, + ReceiveMessage, + ReceivePayloadType, +) from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic CONF_AVAILABLE_MODES_LIST = "modes" @@ -87,18 +95,18 @@ MQTT_HUMIDIFIER_ATTRIBUTES_BLOCKED = frozenset( _LOGGER = logging.getLogger(__name__) -def valid_mode_configuration(config): +def valid_mode_configuration(config: ConfigType) -> ConfigType: """Validate that the mode reset payload is not one of the available modes.""" - if config.get(CONF_PAYLOAD_RESET_MODE) in config.get(CONF_AVAILABLE_MODES_LIST): + if config[CONF_PAYLOAD_RESET_MODE] in config[CONF_AVAILABLE_MODES_LIST]: raise ValueError("modes must not contain payload_reset_mode") return config -def valid_humidity_range_configuration(config): +def valid_humidity_range_configuration(config: ConfigType) -> ConfigType: """Validate that the target_humidity range configuration is valid, throws if it isn't.""" - if config.get(CONF_TARGET_HUMIDITY_MIN) >= config.get(CONF_TARGET_HUMIDITY_MAX): + if config[CONF_TARGET_HUMIDITY_MIN] >= config[CONF_TARGET_HUMIDITY_MAX]: raise ValueError("target_humidity_max must be > target_humidity_min") - if config.get(CONF_TARGET_HUMIDITY_MAX) > 100: + if config[CONF_TARGET_HUMIDITY_MAX] > 100: raise ValueError("max_humidity must be <= 100") return config @@ -196,8 +204,8 @@ async def _async_setup_entity( hass: HomeAssistant, async_add_entities: AddEntitiesCallback, config: ConfigType, - config_entry: ConfigEntry | None = None, - discovery_data: dict | None = None, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None = None, ) -> None: """Set up the MQTT humidifier.""" async_add_entities([MqttHumidifier(hass, config, config_entry, discovery_data)]) @@ -209,30 +217,36 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): _entity_id_format = humidifier.ENTITY_ID_FORMAT _attributes_extra_blocked = MQTT_HUMIDIFIER_ATTRIBUTES_BLOCKED - def __init__(self, hass, config, config_entry, discovery_data): + _command_templates: dict[str, Callable[[PublishPayloadType], PublishPayloadType]] + _value_templates: dict[str, Callable[[ReceivePayloadType], ReceivePayloadType]] + _optimistic: bool + _optimistic_target_humidity: bool + _optimistic_mode: bool + _payload: dict[str, str] + _topic: dict[str, Any] + + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None, + ) -> None: """Initialize the MQTT humidifier.""" self._attr_mode = None - self._topic = None - self._payload = None - self._value_templates = None - self._command_templates = None - self._optimistic = None - self._optimistic_target_humidity = None - self._optimistic_mode = None - MqttEntity.__init__(self, hass, config, config_entry, discovery_data) @staticmethod - def config_schema(): + def config_schema() -> vol.Schema: """Return the config schema.""" return DISCOVERY_SCHEMA - def _setup_from_config(self, config): + def _setup_from_config(self, config: ConfigType) -> None: """(Re)Setup the entity.""" self._attr_device_class = config.get(CONF_DEVICE_CLASS) - self._attr_min_humidity = config.get(CONF_TARGET_HUMIDITY_MIN) - self._attr_max_humidity = config.get(CONF_TARGET_HUMIDITY_MAX) + self._attr_min_humidity = config[CONF_TARGET_HUMIDITY_MIN] + self._attr_max_humidity = config[CONF_TARGET_HUMIDITY_MAX] self._topic = { key: config.get(key) @@ -245,16 +259,6 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): CONF_MODE_COMMAND_TOPIC, ) } - self._value_templates = { - CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE), - ATTR_HUMIDITY: config.get(CONF_TARGET_HUMIDITY_STATE_TEMPLATE), - ATTR_MODE: config.get(CONF_MODE_STATE_TEMPLATE), - } - self._command_templates = { - CONF_STATE: config.get(CONF_COMMAND_TEMPLATE), - ATTR_HUMIDITY: config.get(CONF_TARGET_HUMIDITY_COMMAND_TEMPLATE), - ATTR_MODE: config.get(CONF_MODE_COMMAND_TEMPLATE), - } self._payload = { "STATE_ON": config[CONF_PAYLOAD_ON], "STATE_OFF": config[CONF_PAYLOAD_OFF], @@ -270,31 +274,43 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): else: self._attr_supported_features = 0 - optimistic = config[CONF_OPTIMISTIC] + optimistic: bool = config[CONF_OPTIMISTIC] self._optimistic = optimistic or self._topic[CONF_STATE_TOPIC] is None self._optimistic_target_humidity = ( optimistic or self._topic[CONF_TARGET_HUMIDITY_STATE_TOPIC] is None ) self._optimistic_mode = optimistic or self._topic[CONF_MODE_STATE_TOPIC] is None - for key, tpl in self._command_templates.items(): + self._command_templates = {} + command_templates: dict[str, Template | None] = { + CONF_STATE: config.get(CONF_COMMAND_TEMPLATE), + ATTR_HUMIDITY: config.get(CONF_TARGET_HUMIDITY_COMMAND_TEMPLATE), + ATTR_MODE: config.get(CONF_MODE_COMMAND_TEMPLATE), + } + for key, tpl in command_templates.items(): self._command_templates[key] = MqttCommandTemplate( tpl, entity=self ).async_render - for key, tpl in self._value_templates.items(): + self._value_templates = {} + value_templates: dict[str, Template | None] = { + CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE), + ATTR_HUMIDITY: config.get(CONF_TARGET_HUMIDITY_STATE_TEMPLATE), + ATTR_MODE: config.get(CONF_MODE_STATE_TEMPLATE), + } + for key, tpl in value_templates.items(): self._value_templates[key] = MqttValueTemplate( tpl, entity=self, ).async_render_with_possible_json_value - def _prepare_subscribe_topics(self): + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - topics = {} + topics: dict[str, Any] = {} @callback @log_messages(self.hass, self.entity_id) - def state_received(msg): + def state_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message.""" payload = self._value_templates[CONF_STATE](msg.payload) if not payload: @@ -318,7 +334,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): @callback @log_messages(self.hass, self.entity_id) - def target_humidity_received(msg): + def target_humidity_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message for the target humidity.""" rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY]( msg.payload @@ -365,9 +381,9 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): @callback @log_messages(self.hass, self.entity_id) - def mode_received(msg): + def mode_received(msg: ReceiveMessage) -> None: """Handle new received MQTT message for mode.""" - mode = self._value_templates[ATTR_MODE](msg.payload) + mode = str(self._value_templates[ATTR_MODE](msg.payload)) if mode == self._payload["MODE_RESET"]: self._attr_mode = None get_mqtt_data(self.hass).state_write_requests.write_state_request(self) @@ -375,7 +391,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): if not mode: _LOGGER.debug("Ignoring empty mode from '%s'", msg.topic) return - if mode not in self.available_modes: + if not self.available_modes or mode not in self.available_modes: _LOGGER.warning( "'%s' received on topic %s. '%s' is not a valid mode", msg.payload, @@ -400,7 +416,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): self.hass, self._sub_state, topics ) - async def _subscribe_topics(self): + async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await subscription.async_subscribe_topics(self.hass, self._sub_state) diff --git a/homeassistant/components/mqtt/lock.py b/homeassistant/components/mqtt/lock.py index c9bdd696896..e141fcbd693 100644 --- a/homeassistant/components/mqtt/lock.py +++ b/homeassistant/components/mqtt/lock.py @@ -1,6 +1,7 @@ """Support for MQTT locks.""" from __future__ import annotations +from collections.abc import Callable import functools from typing import Any @@ -32,7 +33,7 @@ from .mixins import ( async_setup_platform_helper, warn_for_legacy_schema, ) -from .models import MqttValueTemplate +from .models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType from .util import get_mqtt_data CONF_PAYLOAD_LOCK = "payload_lock" @@ -112,8 +113,8 @@ async def _async_setup_entity( hass: HomeAssistant, async_add_entities: AddEntitiesCallback, config: ConfigType, - config_entry: ConfigEntry | None = None, - discovery_data: dict | None = None, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None = None, ) -> None: """Set up the MQTT Lock platform.""" async_add_entities([MqttLock(hass, config, config_entry, discovery_data)]) @@ -125,39 +126,50 @@ class MqttLock(MqttEntity, LockEntity): _entity_id_format = lock.ENTITY_ID_FORMAT _attributes_extra_blocked = MQTT_LOCK_ATTRIBUTES_BLOCKED - def __init__(self, hass, config, config_entry, discovery_data): - """Initialize the lock.""" - self._state = False - self._optimistic = False + _optimistic: bool + _value_template: Callable[[ReceivePayloadType], ReceivePayloadType] + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None, + ) -> None: + """Initialize the lock.""" + self._attr_is_locked = False MqttEntity.__init__(self, hass, config, config_entry, discovery_data) @staticmethod - def config_schema(): + def config_schema() -> vol.Schema: """Return the config schema.""" return DISCOVERY_SCHEMA - def _setup_from_config(self, config): + def _setup_from_config(self, config: ConfigType) -> None: """(Re)Setup the entity.""" self._optimistic = config[CONF_OPTIMISTIC] self._value_template = MqttValueTemplate( - self._config.get(CONF_VALUE_TEMPLATE), + config.get(CONF_VALUE_TEMPLATE), entity=self, ).async_render_with_possible_json_value - def _prepare_subscribe_topics(self): + self._attr_supported_features = ( + LockEntityFeature.OPEN if CONF_PAYLOAD_OPEN in config else 0 + ) + + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @callback @log_messages(self.hass, self.entity_id) - def message_received(msg): + def message_received(msg: ReceiveMessage) -> None: """Handle new MQTT messages.""" payload = self._value_template(msg.payload) if payload == self._config[CONF_STATE_LOCKED]: - self._state = True + self._attr_is_locked = True elif payload == self._config[CONF_STATE_UNLOCKED]: - self._state = False + self._attr_is_locked = False get_mqtt_data(self.hass).state_write_requests.write_state_request(self) @@ -178,25 +190,15 @@ class MqttLock(MqttEntity, LockEntity): }, ) - async def _subscribe_topics(self): + async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await subscription.async_subscribe_topics(self.hass, self._sub_state) - @property - def is_locked(self) -> bool: - """Return true if lock is locked.""" - return self._state - @property def assumed_state(self) -> bool: """Return true if we do optimistic updates.""" return self._optimistic - @property - def supported_features(self) -> int: - """Flag supported features.""" - return LockEntityFeature.OPEN if CONF_PAYLOAD_OPEN in self._config else 0 - async def async_lock(self, **kwargs: Any) -> None: """Lock the device. @@ -211,7 +213,7 @@ class MqttLock(MqttEntity, LockEntity): ) if self._optimistic: # Optimistically assume that the lock has changed state. - self._state = True + self._attr_is_locked = True self.async_write_ha_state() async def async_unlock(self, **kwargs: Any) -> None: @@ -228,7 +230,7 @@ class MqttLock(MqttEntity, LockEntity): ) if self._optimistic: # Optimistically assume that the lock has changed state. - self._state = False + self._attr_is_locked = False self.async_write_ha_state() async def async_open(self, **kwargs: Any) -> None: @@ -245,5 +247,5 @@ class MqttLock(MqttEntity, LockEntity): ) if self._optimistic: # Optimistically assume that the lock unlocks when opened. - self._state = False + self._attr_is_locked = False self.async_write_ha_state() diff --git a/homeassistant/components/mqtt/number.py b/homeassistant/components/mqtt/number.py index 95dbe970430..51b480c036e 100644 --- a/homeassistant/components/mqtt/number.py +++ b/homeassistant/components/mqtt/number.py @@ -1,6 +1,7 @@ """Configure number in a device through MQTT topic.""" from __future__ import annotations +from collections.abc import Callable import functools import logging @@ -47,7 +48,13 @@ from .mixins import ( async_setup_platform_helper, warn_for_legacy_schema, ) -from .models import MqttCommandTemplate, MqttValueTemplate +from .models import ( + MqttCommandTemplate, + MqttValueTemplate, + PublishPayloadType, + ReceiveMessage, + ReceivePayloadType, +) from .util import get_mqtt_data _LOGGER = logging.getLogger(__name__) @@ -70,9 +77,9 @@ MQTT_NUMBER_ATTRIBUTES_BLOCKED = frozenset( ) -def validate_config(config): +def validate_config(config: ConfigType) -> ConfigType: """Validate that the configuration is valid, throws if it isn't.""" - if config.get(CONF_MIN) >= config.get(CONF_MAX): + if config[CONF_MIN] >= config[CONF_MAX]: raise vol.Invalid(f"'{CONF_MAX}' must be > '{CONF_MIN}'") return config @@ -147,8 +154,8 @@ async def _async_setup_entity( hass: HomeAssistant, async_add_entities: AddEntitiesCallback, config: ConfigType, - config_entry: ConfigEntry | None = None, - discovery_data: dict | None = None, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None = None, ) -> None: """Set up the MQTT number.""" async_add_entities([MqttNumber(hass, config, config_entry, discovery_data)]) @@ -160,33 +167,39 @@ class MqttNumber(MqttEntity, RestoreNumber): _entity_id_format = number.ENTITY_ID_FORMAT _attributes_extra_blocked = MQTT_NUMBER_ATTRIBUTES_BLOCKED - def __init__(self, hass, config, config_entry, discovery_data): - """Initialize the MQTT Number.""" - self._config = config - self._optimistic = False - self._sub_state = None + _optimistic: bool + _command_template: Callable[[PublishPayloadType], PublishPayloadType] + _value_template: Callable[[ReceivePayloadType], ReceivePayloadType] + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None, + ) -> None: + """Initialize the MQTT Number.""" RestoreNumber.__init__(self) MqttEntity.__init__(self, hass, config, config_entry, discovery_data) @staticmethod - def config_schema(): + def config_schema() -> vol.Schema: """Return the config schema.""" return DISCOVERY_SCHEMA - def _setup_from_config(self, config): + def _setup_from_config(self, config: ConfigType) -> None: """(Re)Setup the entity.""" + self._config = config self._optimistic = config[CONF_OPTIMISTIC] - self._templates = { - CONF_COMMAND_TEMPLATE: MqttCommandTemplate( - config.get(CONF_COMMAND_TEMPLATE), entity=self - ).async_render, - CONF_VALUE_TEMPLATE: MqttValueTemplate( - config.get(CONF_VALUE_TEMPLATE), - entity=self, - ).async_render_with_possible_json_value, - } + self._command_template = MqttCommandTemplate( + config.get(CONF_COMMAND_TEMPLATE), entity=self + ).async_render + self._value_template = MqttValueTemplate( + config.get(CONF_VALUE_TEMPLATE), + entity=self, + ).async_render_with_possible_json_value + self._attr_device_class = config.get(CONF_DEVICE_CLASS) self._attr_mode = config[CONF_MODE] self._attr_native_max_value = config[CONF_MAX] @@ -194,14 +207,15 @@ class MqttNumber(MqttEntity, RestoreNumber): self._attr_native_step = config[CONF_STEP] self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT) - def _prepare_subscribe_topics(self): + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @callback @log_messages(self.hass, self.entity_id) - def message_received(msg): + def message_received(msg: ReceiveMessage) -> None: """Handle new MQTT messages.""" - payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload) + num_value: int | float | None + payload = str(self._value_template(msg.payload)) try: if payload == self._config[CONF_PAYLOAD_RESET]: num_value = None @@ -245,7 +259,7 @@ class MqttNumber(MqttEntity, RestoreNumber): }, ) - async def _subscribe_topics(self): + async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await subscription.async_subscribe_topics(self.hass, self._sub_state) @@ -260,7 +274,7 @@ class MqttNumber(MqttEntity, RestoreNumber): if value.is_integer(): current_number = int(value) - payload = self._templates[CONF_COMMAND_TEMPLATE](current_number) + payload = self._command_template(current_number) if self._optimistic: self._attr_native_value = current_number