Improve MQTT type hints part 3 (#80542)
* Improve typing debug_info * Improve typing device_automation * Improve typing device_trigger * Improve typing fan * Additional type hints device_trigger * Set fan type hints to class level * Cleanup and mypy * Follow up and missed hint * Follow up commentpull/81479/head
parent
dcd1ab7ec3
commit
b3403d7fca
|
@ -28,7 +28,7 @@ def log_messages(
|
|||
|
||||
debug_info_entities = get_mqtt_data(hass).debug_info_entities
|
||||
|
||||
def _log_message(msg):
|
||||
def _log_message(msg: Any) -> None:
|
||||
"""Log message."""
|
||||
messages = debug_info_entities[entity_id]["subscriptions"][
|
||||
msg.subscribed_topic
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
"""Provides device automations for MQTT."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from . import device_trigger
|
||||
from .config import MQTT_BASE_SCHEMA
|
||||
|
@ -20,14 +25,19 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend(
|
|||
).extend(MQTT_BASE_SCHEMA.schema)
|
||||
|
||||
|
||||
async def async_setup_entry(hass, config_entry):
|
||||
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
|
||||
"""Set up MQTT device automation dynamically through MQTT discovery."""
|
||||
|
||||
setup = functools.partial(_async_setup_automation, hass, config_entry=config_entry)
|
||||
await async_setup_entry_helper(hass, "device_automation", setup, PLATFORM_SCHEMA)
|
||||
|
||||
|
||||
async def _async_setup_automation(hass, config, config_entry, discovery_data):
|
||||
async def _async_setup_automation(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
config_entry: ConfigEntry,
|
||||
discovery_data: DiscoveryInfoType,
|
||||
) -> None:
|
||||
"""Set up an MQTT device automation."""
|
||||
if config[CONF_AUTOMATION_TYPE] == AUTOMATION_TYPE_TRIGGER:
|
||||
await device_trigger.async_setup_trigger(
|
||||
|
@ -35,6 +45,6 @@ async def _async_setup_automation(hass, config, config_entry, discovery_data):
|
|||
)
|
||||
|
||||
|
||||
async def async_removed_from_device(hass, device_id):
|
||||
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
|
||||
"""Handle Mqtt removed from a device."""
|
||||
await device_trigger.async_removed_from_device(hass, device_id)
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
import attr
|
||||
import voluptuous as vol
|
||||
|
@ -23,7 +23,7 @@ from homeassistant.exceptions import HomeAssistantError
|
|||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from . import debug_info, trigger as mqtt_trigger
|
||||
from .config import MQTT_BASE_SCHEMA
|
||||
|
@ -35,7 +35,7 @@ from .const import (
|
|||
CONF_TOPIC,
|
||||
DOMAIN,
|
||||
)
|
||||
from .discovery import MQTT_DISCOVERY_DONE
|
||||
from .discovery import MQTT_DISCOVERY_DONE, MQTTDiscoveryPayload
|
||||
from .mixins import (
|
||||
MQTT_ENTITY_DEVICE_INFO_SCHEMA,
|
||||
MqttDiscoveryDeviceUpdate,
|
||||
|
@ -96,7 +96,7 @@ class TriggerInstance:
|
|||
|
||||
async def async_attach_trigger(self) -> None:
|
||||
"""Attach MQTT trigger."""
|
||||
mqtt_config = {
|
||||
mqtt_config: dict[str, Any] = {
|
||||
CONF_PLATFORM: DOMAIN,
|
||||
CONF_TOPIC: self.trigger.topic,
|
||||
CONF_ENCODING: DEFAULT_ENCODING,
|
||||
|
@ -123,7 +123,7 @@ class Trigger:
|
|||
"""Device trigger settings."""
|
||||
|
||||
device_id: str = attr.ib()
|
||||
discovery_data: dict | None = attr.ib()
|
||||
discovery_data: DiscoveryInfoType | None = attr.ib()
|
||||
hass: HomeAssistant = attr.ib()
|
||||
payload: str | None = attr.ib()
|
||||
qos: int | None = attr.ib()
|
||||
|
@ -193,7 +193,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
|
|||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
device_id: str,
|
||||
discovery_data: dict,
|
||||
discovery_data: DiscoveryInfoType,
|
||||
config_entry: ConfigEntry,
|
||||
) -> None:
|
||||
"""Initialize."""
|
||||
|
@ -237,7 +237,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
|
|||
self.hass, discovery_hash, self.discovery_data, self.device_id
|
||||
)
|
||||
|
||||
async def async_update(self, discovery_data: dict) -> None:
|
||||
async def async_update(self, discovery_data: MQTTDiscoveryPayload) -> None:
|
||||
"""Handle MQTT device trigger discovery updates."""
|
||||
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
|
||||
discovery_id = discovery_hash[1]
|
||||
|
@ -261,11 +261,14 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
|
|||
|
||||
|
||||
async def async_setup_trigger(
|
||||
hass, config: ConfigType, config_entry: ConfigEntry, discovery_data: dict
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
config_entry: ConfigEntry,
|
||||
discovery_data: DiscoveryInfoType,
|
||||
) -> None:
|
||||
"""Set up the MQTT device trigger."""
|
||||
config = TRIGGER_DISCOVERY_SCHEMA(config)
|
||||
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
|
||||
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
|
||||
|
||||
if (device_id := update_device(hass, config_entry, config)) is None:
|
||||
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Support for MQTT fans."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
|
@ -27,6 +28,7 @@ from homeassistant.const import (
|
|||
from homeassistant.core import HomeAssistant, callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.template import Template
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.util.percentage import (
|
||||
int_states_in_range,
|
||||
|
@ -54,7 +56,13 @@ from .mixins import (
|
|||
async_setup_platform_helper,
|
||||
warn_for_legacy_schema,
|
||||
)
|
||||
from .models import MqttCommandTemplate, MqttValueTemplate
|
||||
from .models import (
|
||||
MqttCommandTemplate,
|
||||
MqttValueTemplate,
|
||||
PublishPayloadType,
|
||||
ReceiveMessage,
|
||||
ReceivePayloadType,
|
||||
)
|
||||
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
|
||||
|
||||
CONF_PERCENTAGE_STATE_TOPIC = "percentage_state_topic"
|
||||
|
@ -110,18 +118,18 @@ MQTT_FAN_ATTRIBUTES_BLOCKED = frozenset(
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def valid_speed_range_configuration(config):
|
||||
def valid_speed_range_configuration(config: ConfigType) -> ConfigType:
|
||||
"""Validate that the fan speed_range configuration is valid, throws if it isn't."""
|
||||
if config.get(CONF_SPEED_RANGE_MIN) == 0:
|
||||
if config[CONF_SPEED_RANGE_MIN] == 0:
|
||||
raise ValueError("speed_range_min must be > 0")
|
||||
if config.get(CONF_SPEED_RANGE_MIN) >= config.get(CONF_SPEED_RANGE_MAX):
|
||||
if config[CONF_SPEED_RANGE_MIN] >= config[CONF_SPEED_RANGE_MAX]:
|
||||
raise ValueError("speed_range_max must be > speed_range_min")
|
||||
return config
|
||||
|
||||
|
||||
def valid_preset_mode_configuration(config):
|
||||
def valid_preset_mode_configuration(config: ConfigType) -> ConfigType:
|
||||
"""Validate that the preset mode reset payload is not one of the preset modes."""
|
||||
if config.get(CONF_PAYLOAD_RESET_PRESET_MODE) in config.get(CONF_PRESET_MODES_LIST):
|
||||
if config[CONF_PAYLOAD_RESET_PRESET_MODE] in config[CONF_PRESET_MODES_LIST]:
|
||||
raise ValueError("preset_modes must not contain payload_reset_preset_mode")
|
||||
return config
|
||||
|
||||
|
@ -250,8 +258,8 @@ async def _async_setup_entity(
|
|||
hass: HomeAssistant,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
config: ConfigType,
|
||||
config_entry: ConfigEntry | None = None,
|
||||
discovery_data: dict | None = None,
|
||||
config_entry: ConfigEntry,
|
||||
discovery_data: DiscoveryInfoType | None = None,
|
||||
) -> None:
|
||||
"""Set up the MQTT fan."""
|
||||
async_add_entities([MqttFan(hass, config, config_entry, discovery_data)])
|
||||
|
@ -263,32 +271,41 @@ class MqttFan(MqttEntity, FanEntity):
|
|||
_entity_id_format = fan.ENTITY_ID_FORMAT
|
||||
_attributes_extra_blocked = MQTT_FAN_ATTRIBUTES_BLOCKED
|
||||
|
||||
def __init__(self, hass, config, config_entry, discovery_data):
|
||||
_command_templates: dict[str, Callable[[PublishPayloadType], PublishPayloadType]]
|
||||
_value_templates: dict[str, Callable[[ReceivePayloadType], ReceivePayloadType]]
|
||||
_feature_percentage: bool
|
||||
_feature_preset_mode: bool
|
||||
_topic: dict[str, Any]
|
||||
_optimistic: bool
|
||||
_optimistic_oscillation: bool
|
||||
_optimistic_percentage: bool
|
||||
_optimistic_preset_mode: bool
|
||||
_payload: dict[str, Any]
|
||||
_speed_range: tuple[int, int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
config_entry: ConfigEntry,
|
||||
discovery_data: DiscoveryInfoType | None,
|
||||
) -> None:
|
||||
"""Initialize the MQTT fan."""
|
||||
self._attr_percentage = None
|
||||
self._attr_preset_mode = None
|
||||
|
||||
self._topic = None
|
||||
self._payload = None
|
||||
self._value_templates = None
|
||||
self._command_templates = None
|
||||
self._optimistic = None
|
||||
self._optimistic_oscillation = None
|
||||
self._optimistic_percentage = None
|
||||
self._optimistic_preset_mode = None
|
||||
|
||||
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
||||
|
||||
@staticmethod
|
||||
def config_schema():
|
||||
def config_schema() -> vol.Schema:
|
||||
"""Return the config schema."""
|
||||
return DISCOVERY_SCHEMA
|
||||
|
||||
def _setup_from_config(self, config):
|
||||
def _setup_from_config(self, config: ConfigType) -> None:
|
||||
"""(Re)Setup the entity."""
|
||||
self._speed_range = (
|
||||
config.get(CONF_SPEED_RANGE_MIN),
|
||||
config.get(CONF_SPEED_RANGE_MAX),
|
||||
config[CONF_SPEED_RANGE_MIN],
|
||||
config[CONF_SPEED_RANGE_MAX],
|
||||
)
|
||||
self._topic = {
|
||||
key: config.get(key)
|
||||
|
@ -303,18 +320,6 @@ class MqttFan(MqttEntity, FanEntity):
|
|||
CONF_OSCILLATION_COMMAND_TOPIC,
|
||||
)
|
||||
}
|
||||
self._value_templates = {
|
||||
CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE),
|
||||
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_VALUE_TEMPLATE),
|
||||
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_VALUE_TEMPLATE),
|
||||
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_VALUE_TEMPLATE),
|
||||
}
|
||||
self._command_templates = {
|
||||
CONF_STATE: config.get(CONF_COMMAND_TEMPLATE),
|
||||
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_COMMAND_TEMPLATE),
|
||||
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_COMMAND_TEMPLATE),
|
||||
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_COMMAND_TEMPLATE),
|
||||
}
|
||||
self._payload = {
|
||||
"STATE_ON": config[CONF_PAYLOAD_ON],
|
||||
"STATE_OFF": config[CONF_PAYLOAD_OFF],
|
||||
|
@ -359,24 +364,38 @@ class MqttFan(MqttEntity, FanEntity):
|
|||
if self._feature_preset_mode:
|
||||
self._attr_supported_features |= FanEntityFeature.PRESET_MODE
|
||||
|
||||
for key, tpl in self._command_templates.items():
|
||||
command_templates: dict[str, Template | None] = {
|
||||
CONF_STATE: config.get(CONF_COMMAND_TEMPLATE),
|
||||
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_COMMAND_TEMPLATE),
|
||||
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_COMMAND_TEMPLATE),
|
||||
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_COMMAND_TEMPLATE),
|
||||
}
|
||||
self._command_templates = {}
|
||||
for key, tpl in command_templates.items():
|
||||
self._command_templates[key] = MqttCommandTemplate(
|
||||
tpl, entity=self
|
||||
).async_render
|
||||
|
||||
for key, tpl in self._value_templates.items():
|
||||
self._value_templates = {}
|
||||
value_templates: dict[str, Template | None] = {
|
||||
CONF_STATE: config.get(CONF_STATE_VALUE_TEMPLATE),
|
||||
ATTR_PERCENTAGE: config.get(CONF_PERCENTAGE_VALUE_TEMPLATE),
|
||||
ATTR_PRESET_MODE: config.get(CONF_PRESET_MODE_VALUE_TEMPLATE),
|
||||
ATTR_OSCILLATING: config.get(CONF_OSCILLATION_VALUE_TEMPLATE),
|
||||
}
|
||||
for key, tpl in value_templates.items():
|
||||
self._value_templates[key] = MqttValueTemplate(
|
||||
tpl,
|
||||
entity=self,
|
||||
).async_render_with_possible_json_value
|
||||
|
||||
def _prepare_subscribe_topics(self):
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
topics = {}
|
||||
topics: dict[str, Any] = {}
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def state_received(msg):
|
||||
def state_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message."""
|
||||
payload = self._value_templates[CONF_STATE](msg.payload)
|
||||
if not payload:
|
||||
|
@ -400,7 +419,7 @@ class MqttFan(MqttEntity, FanEntity):
|
|||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def percentage_received(msg):
|
||||
def percentage_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the percentage."""
|
||||
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
|
||||
msg.payload
|
||||
|
@ -446,9 +465,9 @@ class MqttFan(MqttEntity, FanEntity):
|
|||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def preset_mode_received(msg):
|
||||
def preset_mode_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for preset mode."""
|
||||
preset_mode = self._value_templates[ATTR_PRESET_MODE](msg.payload)
|
||||
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
|
||||
if preset_mode == self._payload["PRESET_MODE_RESET"]:
|
||||
self._attr_preset_mode = None
|
||||
self.async_write_ha_state()
|
||||
|
@ -456,7 +475,7 @@ class MqttFan(MqttEntity, FanEntity):
|
|||
if not preset_mode:
|
||||
_LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
|
||||
return
|
||||
if preset_mode not in self.preset_modes:
|
||||
if not self.preset_modes or preset_mode not in self.preset_modes:
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid preset mode",
|
||||
msg.payload,
|
||||
|
@ -479,7 +498,7 @@ class MqttFan(MqttEntity, FanEntity):
|
|||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def oscillation_received(msg):
|
||||
def oscillation_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the oscillation."""
|
||||
payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
|
||||
if not payload:
|
||||
|
@ -504,7 +523,7 @@ class MqttFan(MqttEntity, FanEntity):
|
|||
self.hass, self._sub_state, topics
|
||||
)
|
||||
|
||||
async def _subscribe_topics(self):
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
|
||||
|
|
Loading…
Reference in New Issue