Strict type hints for MQTT integration (#82317)
* Strict type hints for MQTT integration * Fix errors * Additional corrections * Use cv.template to avoid untyped calls * Enable strict typing policy for MQTT integration * Use ignore[no-untyped-call] * Use # type: ignore[unreachable] * Correct cast * Refactor getting discovery_payload * Remove unused type ignore commentspull/82683/head
parent
697b5db3f2
commit
8a8732f0bc
|
@ -187,6 +187,7 @@ homeassistant.components.mjpeg.*
|
|||
homeassistant.components.modbus.*
|
||||
homeassistant.components.modem_callerid.*
|
||||
homeassistant.components.moon.*
|
||||
homeassistant.components.mqtt.*
|
||||
homeassistant.components.mysensors.*
|
||||
homeassistant.components.nam.*
|
||||
homeassistant.components.nanoleaf.*
|
||||
|
|
|
@ -249,7 +249,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
|
|||
@property
|
||||
def code_arm_required(self) -> bool:
|
||||
"""Whether the code is required for arm actions."""
|
||||
return self._config[CONF_CODE_ARM_REQUIRED]
|
||||
return bool(self._config[CONF_CODE_ARM_REQUIRED])
|
||||
|
||||
async def async_alarm_disarm(self, code: str | None = None) -> None:
|
||||
"""Send disarm command.
|
||||
|
|
|
@ -80,7 +80,6 @@ if TYPE_CHECKING:
|
|||
# because integrations should be able to optionally rely on MQTT.
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DISCOVERY_COOLDOWN = 2
|
||||
|
@ -148,16 +147,19 @@ AsyncDeprecatedMessageCallbackType = Callable[
|
|||
[str, ReceivePayloadType, int], Coroutine[Any, Any, None]
|
||||
]
|
||||
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
|
||||
DeprecatedMessageCallbackTypes = Union[
|
||||
AsyncDeprecatedMessageCallbackType, DeprecatedMessageCallbackType
|
||||
]
|
||||
|
||||
|
||||
def wrap_msg_callback(
|
||||
msg_callback: AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType,
|
||||
msg_callback: DeprecatedMessageCallbackTypes,
|
||||
) -> AsyncMessageCallbackType | MessageCallbackType:
|
||||
"""Wrap an MQTT message callback to support deprecated signature."""
|
||||
# Check for partials to properly determine if coroutine function
|
||||
check_func = msg_callback
|
||||
while isinstance(check_func, partial):
|
||||
check_func = check_func.func
|
||||
check_func = check_func.func # type: ignore[unreachable]
|
||||
|
||||
wrapper_func: AsyncMessageCallbackType | MessageCallbackType
|
||||
if asyncio.iscoroutinefunction(check_func):
|
||||
|
@ -170,14 +172,15 @@ def wrap_msg_callback(
|
|||
)
|
||||
|
||||
wrapper_func = async_wrapper
|
||||
else:
|
||||
return wrapper_func
|
||||
|
||||
@wraps(msg_callback)
|
||||
def wrapper(msg: ReceiveMessage) -> None:
|
||||
"""Call with deprecated signature."""
|
||||
msg_callback(msg.topic, msg.payload, msg.qos)
|
||||
@wraps(msg_callback)
|
||||
def wrapper(msg: ReceiveMessage) -> None:
|
||||
"""Call with deprecated signature."""
|
||||
msg_callback(msg.topic, msg.payload, msg.qos)
|
||||
|
||||
wrapper_func = wrapper
|
||||
|
||||
wrapper_func = wrapper
|
||||
return wrapper_func
|
||||
|
||||
|
||||
|
@ -187,8 +190,7 @@ async def async_subscribe(
|
|||
topic: str,
|
||||
msg_callback: AsyncMessageCallbackType
|
||||
| MessageCallbackType
|
||||
| DeprecatedMessageCallbackType
|
||||
| AsyncDeprecatedMessageCallbackType,
|
||||
| DeprecatedMessageCallbackTypes,
|
||||
qos: int = DEFAULT_QOS,
|
||||
encoding: str | None = DEFAULT_ENCODING,
|
||||
) -> CALLBACK_TYPE:
|
||||
|
@ -219,7 +221,7 @@ async def async_subscribe(
|
|||
msg_callback.__name__,
|
||||
)
|
||||
wrapped_msg_callback = wrap_msg_callback(
|
||||
cast(DeprecatedMessageCallbackType, msg_callback)
|
||||
cast(DeprecatedMessageCallbackTypes, msg_callback)
|
||||
)
|
||||
|
||||
async_remove = await mqtt_data.client.async_subscribe(
|
||||
|
|
|
@ -97,7 +97,7 @@ async def async_start( # noqa: C901
|
|||
mqtt_data = get_mqtt_data(hass)
|
||||
mqtt_integrations = {}
|
||||
|
||||
async def async_discovery_message_received(msg) -> None:
|
||||
async def async_discovery_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Process the received message."""
|
||||
mqtt_data.last_discovery = time.time()
|
||||
payload = msg.payload
|
||||
|
@ -122,46 +122,50 @@ async def async_start( # noqa: C901
|
|||
|
||||
if payload:
|
||||
try:
|
||||
payload = json_loads(payload)
|
||||
discovery_payload = MQTTDiscoveryPayload(json_loads(payload))
|
||||
except ValueError:
|
||||
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
|
||||
return
|
||||
else:
|
||||
discovery_payload = MQTTDiscoveryPayload({})
|
||||
|
||||
payload = MQTTDiscoveryPayload(payload)
|
||||
|
||||
for key in list(payload):
|
||||
for key in list(discovery_payload):
|
||||
abbreviated_key = key
|
||||
key = ABBREVIATIONS.get(key, key)
|
||||
payload[key] = payload.pop(abbreviated_key)
|
||||
discovery_payload[key] = discovery_payload.pop(abbreviated_key)
|
||||
|
||||
if CONF_DEVICE in payload:
|
||||
device = payload[CONF_DEVICE]
|
||||
if CONF_DEVICE in discovery_payload:
|
||||
device = discovery_payload[CONF_DEVICE]
|
||||
for key in list(device):
|
||||
abbreviated_key = key
|
||||
key = DEVICE_ABBREVIATIONS.get(key, key)
|
||||
device[key] = device.pop(abbreviated_key)
|
||||
|
||||
if CONF_AVAILABILITY in payload:
|
||||
for availability_conf in cv.ensure_list(payload[CONF_AVAILABILITY]):
|
||||
if CONF_AVAILABILITY in discovery_payload:
|
||||
for availability_conf in cv.ensure_list(
|
||||
discovery_payload[CONF_AVAILABILITY]
|
||||
):
|
||||
if isinstance(availability_conf, dict):
|
||||
for key in list(availability_conf):
|
||||
abbreviated_key = key
|
||||
key = ABBREVIATIONS.get(key, key)
|
||||
availability_conf[key] = availability_conf.pop(abbreviated_key)
|
||||
|
||||
if TOPIC_BASE in payload:
|
||||
base = payload.pop(TOPIC_BASE)
|
||||
for key, value in payload.items():
|
||||
if TOPIC_BASE in discovery_payload:
|
||||
base = discovery_payload.pop(TOPIC_BASE)
|
||||
for key, value in discovery_payload.items():
|
||||
if isinstance(value, str) and value:
|
||||
if value[0] == TOPIC_BASE and key.endswith("topic"):
|
||||
payload[key] = f"{base}{value[1:]}"
|
||||
discovery_payload[key] = f"{base}{value[1:]}"
|
||||
if value[-1] == TOPIC_BASE and key.endswith("topic"):
|
||||
payload[key] = f"{value[:-1]}{base}"
|
||||
if payload.get(CONF_AVAILABILITY):
|
||||
for availability_conf in cv.ensure_list(payload[CONF_AVAILABILITY]):
|
||||
discovery_payload[key] = f"{value[:-1]}{base}"
|
||||
if discovery_payload.get(CONF_AVAILABILITY):
|
||||
for availability_conf in cv.ensure_list(
|
||||
discovery_payload[CONF_AVAILABILITY]
|
||||
):
|
||||
if not isinstance(availability_conf, dict):
|
||||
continue
|
||||
if topic := availability_conf.get(CONF_TOPIC):
|
||||
if topic := str(availability_conf.get(CONF_TOPIC)):
|
||||
if topic[0] == TOPIC_BASE:
|
||||
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
|
||||
if topic[-1] == TOPIC_BASE:
|
||||
|
@ -171,21 +175,25 @@ async def async_start( # noqa: C901
|
|||
discovery_id = " ".join((node_id, object_id)) if node_id else object_id
|
||||
discovery_hash = (component, discovery_id)
|
||||
|
||||
if payload:
|
||||
if discovery_payload:
|
||||
# Attach MQTT topic to the payload, used for debug prints
|
||||
setattr(payload, "__configuration_source__", f"MQTT (topic: '{topic}')")
|
||||
setattr(
|
||||
discovery_payload,
|
||||
"__configuration_source__",
|
||||
f"MQTT (topic: '{topic}')",
|
||||
)
|
||||
discovery_data = {
|
||||
ATTR_DISCOVERY_HASH: discovery_hash,
|
||||
ATTR_DISCOVERY_PAYLOAD: payload,
|
||||
ATTR_DISCOVERY_PAYLOAD: discovery_payload,
|
||||
ATTR_DISCOVERY_TOPIC: topic,
|
||||
}
|
||||
setattr(payload, "discovery_data", discovery_data)
|
||||
setattr(discovery_payload, "discovery_data", discovery_data)
|
||||
|
||||
payload[CONF_PLATFORM] = "mqtt"
|
||||
discovery_payload[CONF_PLATFORM] = "mqtt"
|
||||
|
||||
if discovery_hash in mqtt_data.discovery_pending_discovered:
|
||||
pending = mqtt_data.discovery_pending_discovered[discovery_hash]["pending"]
|
||||
pending.appendleft(payload)
|
||||
pending.appendleft(discovery_payload)
|
||||
_LOGGER.info(
|
||||
"Component has already been discovered: %s %s, queuing update",
|
||||
component,
|
||||
|
@ -193,7 +201,9 @@ async def async_start( # noqa: C901
|
|||
)
|
||||
return
|
||||
|
||||
await async_process_discovery_payload(component, discovery_id, payload)
|
||||
await async_process_discovery_payload(
|
||||
component, discovery_id, discovery_payload
|
||||
)
|
||||
|
||||
async def async_process_discovery_payload(
|
||||
component: str, discovery_id: str, payload: MQTTDiscoveryPayload
|
||||
|
@ -204,7 +214,7 @@ async def async_start( # noqa: C901
|
|||
discovery_hash = (component, discovery_id)
|
||||
if discovery_hash in mqtt_data.discovery_already_discovered or payload:
|
||||
|
||||
async def discovery_done(_) -> None:
|
||||
async def discovery_done(_: Any) -> None:
|
||||
pending = mqtt_data.discovery_pending_discovered[discovery_hash][
|
||||
"pending"
|
||||
]
|
||||
|
|
|
@ -680,7 +680,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||
restore_state(ATTR_HS_COLOR, ATTR_XY_COLOR)
|
||||
|
||||
@property
|
||||
def assumed_state(self):
|
||||
def assumed_state(self) -> bool:
|
||||
"""Return true if we do optimistic updates."""
|
||||
return self._optimistic
|
||||
|
||||
|
|
|
@ -620,7 +620,8 @@ async def cleanup_device_registry(
|
|||
|
||||
def get_discovery_hash(discovery_data: DiscoveryInfoType) -> tuple[str, str]:
|
||||
"""Get the discovery hash from the discovery data."""
|
||||
return discovery_data[ATTR_DISCOVERY_HASH]
|
||||
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
|
||||
return discovery_hash
|
||||
|
||||
|
||||
def send_discovery_done(hass: HomeAssistant, discovery_data: DiscoveryInfoType) -> None:
|
||||
|
@ -1113,7 +1114,7 @@ class MqttEntity(
|
|||
@property
|
||||
def entity_registry_enabled_default(self) -> bool:
|
||||
"""Return if the entity should be enabled when first added to the entity registry."""
|
||||
return self._config[CONF_ENABLED_BY_DEFAULT]
|
||||
return bool(self._config[CONF_ENABLED_BY_DEFAULT])
|
||||
|
||||
@property
|
||||
def entity_category(self) -> EntityCategory | None:
|
||||
|
|
|
@ -150,7 +150,7 @@ class MqttCommandTemplate:
|
|||
if self._entity:
|
||||
values[ATTR_ENTITY_ID] = self._entity.entity_id
|
||||
values[ATTR_NAME] = self._entity.name
|
||||
if not self._template_state:
|
||||
if not self._template_state and self._command_template.hass is not None:
|
||||
self._template_state = template.TemplateStateFromEntityId(
|
||||
self._entity.hass, self._entity.entity_id
|
||||
)
|
||||
|
@ -200,6 +200,8 @@ class MqttValueTemplate:
|
|||
variables: TemplateVarsType = None,
|
||||
) -> ReceivePayloadType:
|
||||
"""Render with possible json value or pass-though a received MQTT value."""
|
||||
rendered_payload: ReceivePayloadType
|
||||
|
||||
if self._value_template is None:
|
||||
return payload
|
||||
|
||||
|
@ -227,9 +229,12 @@ class MqttValueTemplate:
|
|||
values,
|
||||
self._value_template,
|
||||
)
|
||||
return self._value_template.async_render_with_possible_json_value(
|
||||
payload, variables=values
|
||||
rendered_payload = (
|
||||
self._value_template.async_render_with_possible_json_value(
|
||||
payload, variables=values
|
||||
)
|
||||
)
|
||||
return rendered_payload
|
||||
|
||||
_LOGGER.debug(
|
||||
"Rendering incoming payload '%s' with variables %s with default value '%s' and %s",
|
||||
|
@ -238,9 +243,10 @@ class MqttValueTemplate:
|
|||
default,
|
||||
self._value_template,
|
||||
)
|
||||
return self._value_template.async_render_with_possible_json_value(
|
||||
rendered_payload = self._value_template.async_render_with_possible_json_value(
|
||||
payload, default, variables=values
|
||||
)
|
||||
return rendered_payload
|
||||
|
||||
|
||||
class EntityTopicState:
|
||||
|
|
|
@ -19,7 +19,7 @@ class EntitySubscription:
|
|||
"""Class to hold data about an active entity topic subscription."""
|
||||
|
||||
hass: HomeAssistant = attr.ib()
|
||||
topic: str = attr.ib()
|
||||
topic: str | None = attr.ib()
|
||||
message_callback: MessageCallbackType = attr.ib()
|
||||
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None = attr.ib()
|
||||
unsubscribe_callback: Callable[[], None] | None = attr.ib()
|
||||
|
@ -39,7 +39,7 @@ class EntitySubscription:
|
|||
other.unsubscribe_callback()
|
||||
# Clear debug data if it exists
|
||||
debug_info.remove_subscription(
|
||||
self.hass, other.message_callback, other.topic
|
||||
self.hass, other.message_callback, str(other.topic)
|
||||
)
|
||||
|
||||
if self.topic is None:
|
||||
|
@ -112,7 +112,7 @@ def async_prepare_subscribe_topics(
|
|||
remaining.unsubscribe_callback()
|
||||
# Clear debug data if it exists
|
||||
debug_info.remove_subscription(
|
||||
hass, remaining.message_callback, remaining.topic
|
||||
hass, remaining.message_callback, str(remaining.topic)
|
||||
)
|
||||
|
||||
return new_state
|
||||
|
|
|
@ -97,7 +97,7 @@ def valid_subscribe_topic(topic: Any) -> str:
|
|||
|
||||
def valid_subscribe_topic_template(value: Any) -> template.Template:
|
||||
"""Validate either a jinja2 template or a valid MQTT subscription topic."""
|
||||
tpl = template.Template(value)
|
||||
tpl = cv.template(value)
|
||||
|
||||
if tpl.is_static:
|
||||
valid_subscribe_topic(value)
|
||||
|
@ -115,7 +115,8 @@ def valid_publish_topic(topic: Any) -> str:
|
|||
|
||||
def valid_qos_schema(qos: Any) -> int:
|
||||
"""Validate that QOS value is valid."""
|
||||
return _VALID_QOS_SCHEMA(qos)
|
||||
validated_qos: int = _VALID_QOS_SCHEMA(qos)
|
||||
return validated_qos
|
||||
|
||||
|
||||
_MQTT_WILL_BIRTH_SCHEMA = vol.Schema(
|
||||
|
@ -138,9 +139,12 @@ def valid_birth_will(config: ConfigType) -> ConfigType:
|
|||
|
||||
def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData:
|
||||
"""Return typed MqttData from hass.data[DATA_MQTT]."""
|
||||
mqtt_data: MqttData
|
||||
if ensure_exists:
|
||||
return hass.data.setdefault(DATA_MQTT, MqttData())
|
||||
return hass.data[DATA_MQTT]
|
||||
mqtt_data = hass.data.setdefault(DATA_MQTT, MqttData())
|
||||
return mqtt_data
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
return mqtt_data
|
||||
|
||||
|
||||
async def async_create_certificate_temp_files(
|
||||
|
|
10
mypy.ini
10
mypy.ini
|
@ -1623,6 +1623,16 @@ disallow_untyped_defs = true
|
|||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.mqtt.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.mysensors.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
|
|
Loading…
Reference in New Issue