Improve mqtt MessageCallback typing (#75614)

* Improve mqtt MessageCallback typing

* Use MQTTMessage
pull/77704/head
Marc Mueller 2022-07-26 03:04:19 +02:00 committed by GitHub
parent 9c725bc106
commit 2b617e3885
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 8 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable, Iterable
from collections.abc import Awaitable, Callable, Coroutine, Iterable
from functools import lru_cache, partial, wraps
import inspect
from itertools import groupby
@ -15,6 +15,7 @@ import uuid
import attr
import certifi
from paho.mqtt.client import MQTTMessage
from homeassistant.const import (
CONF_CLIENT_ID,
@ -246,7 +247,7 @@ class Subscription:
topic: str = attr.ib()
matcher: Any = attr.ib()
job: HassJob = attr.ib()
job: HassJob[[ReceiveMessage], Coroutine[Any, Any, None] | None] = attr.ib()
qos: int = attr.ib(default=0)
encoding: str | None = attr.ib(default="utf-8")
@ -444,7 +445,7 @@ class MQTT:
async def async_subscribe(
self,
topic: str,
msg_callback: MessageCallbackType,
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
qos: int,
encoding: str | None = None,
) -> Callable[[], None]:
@ -597,15 +598,15 @@ class MQTT:
self.hass.add_job(self._mqtt_handle_message, msg)
@lru_cache(2048)
def _matching_subscriptions(self, topic):
subscriptions = []
def _matching_subscriptions(self, topic: str) -> list[Subscription]:
subscriptions: list[Subscription] = []
for subscription in self.subscriptions:
if subscription.matcher(topic):
subscriptions.append(subscription)
return subscriptions
@callback
def _mqtt_handle_message(self, msg) -> None:
def _mqtt_handle_message(self, msg: MQTTMessage) -> None:
_LOGGER.debug(
"Received message on %s%s: %s",
msg.topic,

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from ast import literal_eval
from collections.abc import Awaitable, Callable
from collections.abc import Callable, Coroutine
import datetime as dt
from typing import Any, Union
@ -42,7 +42,7 @@ class ReceiveMessage:
timestamp: dt.datetime = attr.ib(default=None)
AsyncMessageCallbackType = Callable[[ReceiveMessage], Awaitable[None]]
AsyncMessageCallbackType = Callable[[ReceiveMessage], Coroutine[Any, Any, None]]
MessageCallbackType = Callable[[ReceiveMessage], None]