Improve MQTT type hints part 3 (#80542)

* Improve typing debug_info

* Improve typing device_automation

* Improve typing device_trigger

* Improve typing fan

* Additional type hints device_trigger

* Set fan type hints to class level

* Cleanup and mypy

* Follow up and missed hint

* Follow up comment
pull/81479/head
Jan Bouwhuis 2022-11-03 13:06:53 +01:00 committed by GitHub
parent dcd1ab7ec3
commit b3403d7fca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 58 deletions

View File

@ -28,7 +28,7 @@ def log_messages(
debug_info_entities = get_mqtt_data(hass).debug_info_entities
def _log_message(msg):
def _log_message(msg: Any) -> None:
"""Log message."""
messages = debug_info_entities[entity_id]["subscriptions"][
msg.subscribed_topic

View File

@ -1,9 +1,14 @@
"""Provides device automations for MQTT."""
from __future__ import annotations
import functools
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import device_trigger
from .config import MQTT_BASE_SCHEMA
@ -20,14 +25,19 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
).extend(MQTT_BASE_SCHEMA.schema)
async def async_setup_entry(hass, config_entry):
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Set up MQTT device automation dynamically through MQTT discovery."""
setup = functools.partial(_async_setup_automation, hass, config_entry=config_entry)
await async_setup_entry_helper(hass, "device_automation", setup, PLATFORM_SCHEMA)
async def _async_setup_automation(hass, config, config_entry, discovery_data):
async def _async_setup_automation(
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType,
) -> None:
"""Set up an MQTT device automation."""
if config[CONF_AUTOMATION_TYPE] == AUTOMATION_TYPE_TRIGGER:
await device_trigger.async_setup_trigger(
@ -35,6 +45,6 @@ async def _async_setup_automation(hass, config, config_entry, discovery_data):
)
async def async_removed_from_device(hass, device_id):
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
"""Handle Mqtt removed from a device."""
await device_trigger.async_removed_from_device(hass, device_id)

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Callable
import logging
from typing import cast
from typing import Any, cast
import attr
import voluptuous as vol
@ -23,7 +23,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import debug_info, trigger as mqtt_trigger
from .config import MQTT_BASE_SCHEMA
@ -35,7 +35,7 @@ from .const import (
CONF_TOPIC,
DOMAIN,
)
from .discovery import MQTT_DISCOVERY_DONE
from .discovery import MQTT_DISCOVERY_DONE, MQTTDiscoveryPayload
from .mixins import (
MQTT_ENTITY_DEVICE_INFO_SCHEMA,
MqttDiscoveryDeviceUpdate,
@ -96,7 +96,7 @@ class TriggerInstance:
async def async_attach_trigger(self) -> None:
"""Attach MQTT trigger."""
mqtt_config = {
mqtt_config: dict[str, Any] = {
CONF_PLATFORM: DOMAIN,
CONF_TOPIC: self.trigger.topic,
CONF_ENCODING: DEFAULT_ENCODING,
@ -123,7 +123,7 @@ class Trigger:
"""Device trigger settings."""
device_id: str = attr.ib()
discovery_data: dict | None = attr.ib()
discovery_data: DiscoveryInfoType | None = attr.ib()
hass: HomeAssistant = attr.ib()
payload: str | None = attr.ib()
qos: int | None = attr.ib()
@ -193,7 +193,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
hass: HomeAssistant,
config: ConfigType,
device_id: str,
discovery_data: dict,
discovery_data: DiscoveryInfoType,
config_entry: ConfigEntry,
) -> None:
"""Initialize."""
@ -237,7 +237,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
self.hass, discovery_hash, self.discovery_data, self.device_id
)
async def async_update(self, discovery_data: dict) -> None:
async def async_update(self, discovery_data: MQTTDiscoveryPayload) -> None:
"""Handle MQTT device trigger discovery updates."""
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
discovery_id = discovery_hash[1]
@ -261,11 +261,14 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
async def async_setup_trigger(
hass, config: ConfigType, config_entry: ConfigEntry, discovery_data: dict
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType,
) -> None:
"""Set up the MQTT device trigger."""
config = TRIGGER_DISCOVERY_SCHEMA(config)
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
if (device_id := update_device(hass, config_entry, config)) is None:
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)

View File

@ -1,6 +1,7 @@
"""Support for MQTT fans."""
from __future__ import annotations
from collections.abc import Callable
import functools
import logging
import math
@ -27,6 +28,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.util.percentage import (
int_states_in_range,
@ -54,7 +56,13 @@ from .mixins import (
async_setup_platform_helper,
warn_for_legacy_schema,
)
from .models import MqttCommandTemplate, MqttValueTemplate
from .models import (
MqttCommandTemplate,
MqttValueTemplate,
PublishPayloadType,
ReceiveMessage,
ReceivePayloadType,
)
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
CONF_PERCENTAGE_STATE_TOPIC = "percentage_state_topic"
@ -110,18 +118,18 @@ MQTT_FAN_ATTRIBUTES_BLOCKED = frozenset(
_LOGGER = logging.getLogger(__name__)
def valid_speed_range_configuration(config):
def valid_speed_range_configuration(config: ConfigType) -> ConfigType:
"""Validate that the fan speed_range configuration is valid, throws if it isn't."""
if config.get(CONF_SPEED_RANGE_MIN) == 0:
if config[CONF_SPEED_RANGE_MIN] == 0:
raise ValueError("speed_range_min must be > 0")
if config.get(CONF_SPEED_RANGE_MIN) >= config.get(CONF_SPEED_RANGE_MAX):
if config[CONF_SPEED_RANGE_MIN] >= config[CONF_SPEED_RANGE_MAX]:
raise ValueError("speed_range_max must be > speed_range_min")
return config
def valid_preset_mode_configuration(config):
def valid_preset_mode_configuration(config: ConfigType) -> ConfigType:
"""Validate that the preset mode reset payload is not one of the preset modes."""
if config.get(CONF_PAYLOAD_RESET_PRESET_MODE) in config.get(CONF_PRESET_MODES_LIST):
if config[CONF_PAYLOAD_RESET_PRESET_MODE] in config[CONF_PRESET_MODES_LIST]:
raise ValueError("preset_modes must not contain payload_reset_preset_mode")
return config
@ -250,8 +258,8 @@ async def _async_setup_entity(
hass: HomeAssistant,
async_add_entities: AddEntitiesCallback,
config: ConfigType,
config_entry: ConfigEntry | None = None,
discovery_data: dict | None = None,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None = None,
) -> None:
"""Set up the MQTT fan."""
async_add_entities([MqttFan(hass, config, config_entry, discovery_data)])
@ -263,32 +271,41 @@ class MqttFan(MqttEntity, FanEntity):
_entity_id_format = fan.ENTITY_ID_FORMAT
_attributes_extra_blocked = MQTT_FAN_ATTRIBUTES_BLOCKED
def __init__(self, hass, config, config_entry, discovery_data):
_command_templates: dict[str, Callable[[PublishPayloadType], PublishPayloadType]]
_value_templates: dict[str, Callable[[ReceivePayloadType], ReceivePayloadType]]
_feature_percentage: bool
_feature_preset_mode: bool
_topic: dict[str, Any]
_optimistic: bool
_optimistic_oscillation: bool
_optimistic_percentage: bool
_optimistic_preset_mode: bool
_payload: dict[str, Any]
_speed_range: tuple[int, int]
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the MQTT fan."""
self._attr_percentage = None
self._attr_preset_mode = None
self._topic = None
self._payload = None
self._value_templates = None
self._command_templates = None
self._optimistic = None
self._optimistic_oscillation = None
self._optimistic_percentage = None
self._optimistic_preset_mode = None
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod
def config_schema():
def config_schema() -> vol.Schema:
"""Return the config schema."""
return DISCOVERY_SCHEMA
def _setup_from_config(self, config):
def _setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup the entity."""
self._speed_range = (
config.get(CONF_SPEED_RANGE_MIN),
config.get(CONF_SPEED_RANGE_MAX),
config[CONF_SPEED_RANGE_MIN],
config[CONF_SPEED_RANGE_MAX],
)
self._topic = {
key: config.get(key)
@ -303,18 +320,6 @@ class MqttFan(MqttEntity, FanEntity):
CONF_OSCILLATION_COMMAND_TOPIC,
)
}
self._value_templates = {
CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE),
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_VALUE_TEMPLATE),
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_VALUE_TEMPLATE),
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_VALUE_TEMPLATE),
}
self._command_templates = {
CONF_STATE: config.get(CONF_COMMAND_TEMPLATE),
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_COMMAND_TEMPLATE),
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_COMMAND_TEMPLATE),
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_COMMAND_TEMPLATE),
}
self._payload = {
"STATE_ON": config[CONF_PAYLOAD_ON],
"STATE_OFF": config[CONF_PAYLOAD_OFF],
@ -359,24 +364,38 @@ class MqttFan(MqttEntity, FanEntity):
if self._feature_preset_mode:
self._attr_supported_features |= FanEntityFeature.PRESET_MODE
for key, tpl in self._command_templates.items():
command_templates: dict[str, Template | None] = {
CONF_STATE: config.get(CONF_COMMAND_TEMPLATE),
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_COMMAND_TEMPLATE),
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_COMMAND_TEMPLATE),
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_COMMAND_TEMPLATE),
}
self._command_templates = {}
for key, tpl in command_templates.items():
self._command_templates[key] = MqttCommandTemplate(
tpl, entity=self
).async_render
for key, tpl in self._value_templates.items():
self._value_templates = {}
value_templates: dict[str, Template | None] = {
CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE),
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_VALUE_TEMPLATE),
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_VALUE_TEMPLATE),
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_VALUE_TEMPLATE),
}
for key, tpl in value_templates.items():
self._value_templates[key] = MqttValueTemplate(
tpl,
entity=self,
).async_render_with_possible_json_value
def _prepare_subscribe_topics(self):
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics = {}
topics: dict[str, Any] = {}
@callback
@log_messages(self.hass, self.entity_id)
def state_received(msg):
def state_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload)
if not payload:
@ -400,7 +419,7 @@ class MqttFan(MqttEntity, FanEntity):
@callback
@log_messages(self.hass, self.entity_id)
def percentage_received(msg):
def percentage_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the percentage."""
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
msg.payload
@ -446,9 +465,9 @@ class MqttFan(MqttEntity, FanEntity):
@callback
@log_messages(self.hass, self.entity_id)
def preset_mode_received(msg):
def preset_mode_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for preset mode."""
preset_mode = self._value_templates[ATTR_PRESET_MODE](msg.payload)
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
if preset_mode == self._payload["PRESET_MODE_RESET"]:
self._attr_preset_mode = None
self.async_write_ha_state()
@ -456,7 +475,7 @@ class MqttFan(MqttEntity, FanEntity):
if not preset_mode:
_LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
return
if preset_mode not in self.preset_modes:
if not self.preset_modes or preset_mode not in self.preset_modes:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid preset mode",
msg.payload,
@ -479,7 +498,7 @@ class MqttFan(MqttEntity, FanEntity):
@callback
@log_messages(self.hass, self.entity_id)
def oscillation_received(msg):
def oscillation_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the oscillation."""
payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
if not payload:
@ -504,7 +523,7 @@ class MqttFan(MqttEntity, FanEntity):
self.hass, self._sub_state, topics
)
async def _subscribe_topics(self):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)