From bcae6d604e2967c7475f0caa4b1b5e4e76ab88bf Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Fri, 28 Oct 2022 18:20:33 +0200 Subject: [PATCH] Improve MQTT type hints part 8 (#81034) * Improve typing device_tracker discovery * Improve typing device_tracker yaml * Add test source_type attribute * Follow up comment * Initialize at `__init__` not at class level. * Use full name for return variable * Correct import, remove assert * Use AsyncSeeCallback --- .../mqtt/device_tracker/schema_discovery.py | 62 +++++++++++-------- .../mqtt/device_tracker/schema_yaml.py | 33 ++++++---- .../mqtt/test_device_tracker_discovery.py | 4 +- 3 files changed, 60 insertions(+), 39 deletions(-) diff --git a/homeassistant/components/mqtt/device_tracker/schema_discovery.py b/homeassistant/components/mqtt/device_tracker/schema_discovery.py index 673a9cb04b7..ba088a59c44 100644 --- a/homeassistant/components/mqtt/device_tracker/schema_discovery.py +++ b/homeassistant/components/mqtt/device_tracker/schema_discovery.py @@ -1,13 +1,14 @@ """Support for tracking MQTT enabled devices identified through discovery.""" from __future__ import annotations +from collections.abc import Callable import functools import voluptuous as vol from homeassistant.components import device_tracker -from homeassistant.components.device_tracker import SOURCE_TYPES -from homeassistant.components.device_tracker.config_entry import ( +from homeassistant.components.device_tracker import ( + SOURCE_TYPES, SourceType, TrackerEntity, ) @@ -24,14 +25,14 @@ from homeassistant.const import ( from homeassistant.core import HomeAssistant, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from .. import subscription from ..config import MQTT_RO_SCHEMA from ..const import CONF_QOS, CONF_STATE_TOPIC from ..debug_info import log_messages from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, async_setup_entry_helper -from ..models import MqttValueTemplate +from ..models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType from ..util import get_mqtt_data CONF_PAYLOAD_HOME = "payload_home" @@ -70,8 +71,8 @@ async def _async_setup_entity( hass: HomeAssistant, async_add_entities: AddEntitiesCallback, config: ConfigType, - config_entry: ConfigEntry | None = None, - discovery_data: dict | None = None, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None = None, ) -> None: """Set up the MQTT Device Tracker entity.""" async_add_entities([MqttDeviceTracker(hass, config, config_entry, discovery_data)]) @@ -81,37 +82,44 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity): """Representation of a device tracker using MQTT.""" _entity_id_format = device_tracker.ENTITY_ID_FORMAT + _value_template: Callable[..., ReceivePayloadType] - def __init__(self, hass, config, config_entry, discovery_data): + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None, + ) -> None: """Initialize the tracker.""" - self._location_name = None - + self._location_name: str | None = None MqttEntity.__init__(self, hass, config, config_entry, discovery_data) @staticmethod - def config_schema(): + def config_schema() -> vol.Schema: """Return the config schema.""" return DISCOVERY_SCHEMA - def _setup_from_config(self, config): + def _setup_from_config(self, config: ConfigType) -> None: """(Re)Setup the entity.""" self._value_template = MqttValueTemplate( - self._config.get(CONF_VALUE_TEMPLATE), entity=self + config.get(CONF_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value - def _prepare_subscribe_topics(self): + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @callback @log_messages(self.hass, self.entity_id) - def message_received(msg): + def message_received(msg: ReceiveMessage) -> None: """Handle new MQTT messages.""" - payload = self._value_template(msg.payload) + payload: ReceivePayloadType = self._value_template(msg.payload) if payload == self._config[CONF_PAYLOAD_HOME]: self._location_name = STATE_HOME elif payload == self._config[CONF_PAYLOAD_NOT_HOME]: self._location_name = STATE_NOT_HOME else: + assert isinstance(msg.payload, str) self._location_name = msg.payload get_mqtt_data(self.hass).state_write_requests.write_state_request(self) @@ -128,46 +136,50 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity): }, ) - async def _subscribe_topics(self): + async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await subscription.async_subscribe_topics(self.hass, self._sub_state) @property - def latitude(self): + def latitude(self) -> float | None: """Return latitude if provided in extra_state_attributes or None.""" if ( self.extra_state_attributes is not None and ATTR_LATITUDE in self.extra_state_attributes ): - return self.extra_state_attributes[ATTR_LATITUDE] + latitude: float = self.extra_state_attributes[ATTR_LATITUDE] + return latitude return None @property - def location_accuracy(self): + def location_accuracy(self) -> int: """Return location accuracy if provided in extra_state_attributes or None.""" if ( self.extra_state_attributes is not None and ATTR_GPS_ACCURACY in self.extra_state_attributes ): - return self.extra_state_attributes[ATTR_GPS_ACCURACY] - return None + accuracy: int = self.extra_state_attributes[ATTR_GPS_ACCURACY] + return accuracy + return 0 @property - def longitude(self): + def longitude(self) -> float | None: """Return longitude if provided in extra_state_attributes or None.""" if ( self.extra_state_attributes is not None and ATTR_LONGITUDE in self.extra_state_attributes ): - return self.extra_state_attributes[ATTR_LONGITUDE] + longitude: float = self.extra_state_attributes[ATTR_LONGITUDE] + return longitude return None @property - def location_name(self): + def location_name(self) -> str | None: """Return a location name for the current location of the device.""" return self._location_name @property def source_type(self) -> SourceType | str: """Return the source type, eg gps or router, of the device.""" - return self._config[CONF_SOURCE_TYPE] + source_type: SourceType | str = self._config[CONF_SOURCE_TYPE] + return source_type diff --git a/homeassistant/components/mqtt/device_tracker/schema_yaml.py b/homeassistant/components/mqtt/device_tracker/schema_yaml.py index c005a82dbeb..d88a82e3002 100644 --- a/homeassistant/components/mqtt/device_tracker/schema_yaml.py +++ b/homeassistant/components/mqtt/device_tracker/schema_yaml.py @@ -1,14 +1,19 @@ """Support for tracking MQTT enabled devices defined in YAML.""" from __future__ import annotations -from collections.abc import Awaitable, Callable +from collections.abc import Callable, Coroutine import dataclasses import logging from typing import Any import voluptuous as vol -from homeassistant.components.device_tracker import PLATFORM_SCHEMA, SOURCE_TYPES +from homeassistant.components.device_tracker import ( + PLATFORM_SCHEMA, + SOURCE_TYPES, + AsyncSeeCallback, + SourceType, +) from homeassistant.const import CONF_DEVICES, STATE_HOME, STATE_NOT_HOME from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_validation as cv @@ -18,6 +23,7 @@ from ... import mqtt from ..client import async_subscribe from ..config import SCHEMA_BASE from ..const import CONF_QOS, MQTT_DATA_DEVICE_TRACKER_LEGACY +from ..models import ReceiveMessage from ..util import mqtt_config_entry_enabled, valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -40,22 +46,22 @@ PLATFORM_SCHEMA_YAML = PLATFORM_SCHEMA.extend(SCHEMA_BASE).extend( class MQTTLegacyDeviceTrackerData: """Class to hold device tracker data.""" - async_see: Callable[..., Awaitable[None]] + async_see: Callable[..., Coroutine[Any, Any, None]] config: ConfigType async def async_setup_scanner_from_yaml( hass: HomeAssistant, config: ConfigType, - async_see: Callable[..., Awaitable[None]], + async_see: AsyncSeeCallback, discovery_info: DiscoveryInfoType | None = None, ) -> bool: """Set up the MQTT tracker.""" - devices = config[CONF_DEVICES] - qos = config[CONF_QOS] - payload_home = config[CONF_PAYLOAD_HOME] - payload_not_home = config[CONF_PAYLOAD_NOT_HOME] - source_type = config.get(CONF_SOURCE_TYPE) + devices: dict[str, str] = config[CONF_DEVICES] + qos: int = config[CONF_QOS] + payload_home: str = config[CONF_PAYLOAD_HOME] + payload_not_home: str = config[CONF_PAYLOAD_NOT_HOME] + source_type: SourceType | str | None = config.get(CONF_SOURCE_TYPE) config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] subscriptions: list[Callable] = [] @@ -78,16 +84,19 @@ async def async_setup_scanner_from_yaml( for dev_id, topic in devices.items(): @callback - def async_message_received(msg, dev_id=dev_id): + def async_message_received(msg: ReceiveMessage, dev_id: str = dev_id) -> None: """Handle received MQTT message.""" if msg.payload == payload_home: location_name = STATE_HOME elif msg.payload == payload_not_home: location_name = STATE_NOT_HOME else: - location_name = msg.payload + location_name = str(msg.payload) - see_args = {"dev_id": dev_id, "location_name": location_name} + see_args: dict[str, Any] = { + "dev_id": dev_id, + "location_name": location_name, + } if source_type: see_args["source_type"] = source_type diff --git a/tests/components/mqtt/test_device_tracker_discovery.py b/tests/components/mqtt/test_device_tracker_discovery.py index 6324afeb3ef..873ca1ed9a1 100644 --- a/tests/components/mqtt/test_device_tracker_discovery.py +++ b/tests/components/mqtt/test_device_tracker_discovery.py @@ -410,12 +410,12 @@ async def test_setting_device_tracker_location_via_lat_lon_message( async_fire_mqtt_message( hass, "attributes-topic", - '{"latitude":50.1,"longitude": -2.1, "gps_accuracy":1.5}', + '{"latitude":50.1,"longitude": -2.1}', ) state = hass.states.get("device_tracker.test") assert state.attributes["latitude"] == 50.1 assert state.attributes["longitude"] == -2.1 - assert state.attributes["gps_accuracy"] == 1.5 + assert state.attributes["gps_accuracy"] == 0 assert state.state == STATE_NOT_HOME async_fire_mqtt_message(hass, "attributes-topic", '{"longitude": -117.22743}')