Improve MQTT type hints part 8 (#81034)

* Improve typing device_tracker discovery

* Improve typing device_tracker yaml

* Add test source_type attribute

* Follow up comment

* Initialize at `__init__` not at class level.

* Use full name for return variable

* Correct import, remove assert

* Use AsyncSeeCallback
pull/81144/head^2
Jan Bouwhuis 2022-10-28 18:20:33 +02:00 committed by GitHub
parent 2214fff3b4
commit bcae6d604e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 39 deletions

View File

@ -1,13 +1,14 @@
"""Support for tracking MQTT enabled devices identified through discovery.""" """Support for tracking MQTT enabled devices identified through discovery."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import functools import functools
import voluptuous as vol import voluptuous as vol
from homeassistant.components import device_tracker from homeassistant.components import device_tracker
from homeassistant.components.device_tracker import SOURCE_TYPES from homeassistant.components.device_tracker import (
from homeassistant.components.device_tracker.config_entry import ( SOURCE_TYPES,
SourceType, SourceType,
TrackerEntity, TrackerEntity,
) )
@ -24,14 +25,14 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .. import subscription from .. import subscription
from ..config import MQTT_RO_SCHEMA from ..config import MQTT_RO_SCHEMA
from ..const import CONF_QOS, CONF_STATE_TOPIC from ..const import CONF_QOS, CONF_STATE_TOPIC
from ..debug_info import log_messages from ..debug_info import log_messages
from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, async_setup_entry_helper from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, async_setup_entry_helper
from ..models import MqttValueTemplate from ..models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType
from ..util import get_mqtt_data from ..util import get_mqtt_data
CONF_PAYLOAD_HOME = "payload_home" CONF_PAYLOAD_HOME = "payload_home"
@ -70,8 +71,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the MQTT Device Tracker entity.""" """Set up the MQTT Device Tracker entity."""
async_add_entities([MqttDeviceTracker(hass, config, config_entry, discovery_data)]) async_add_entities([MqttDeviceTracker(hass, config, config_entry, discovery_data)])
@ -81,37 +82,44 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
"""Representation of a device tracker using MQTT.""" """Representation of a device tracker using MQTT."""
_entity_id_format = device_tracker.ENTITY_ID_FORMAT _entity_id_format = device_tracker.ENTITY_ID_FORMAT
_value_template: Callable[..., ReceivePayloadType]
def __init__(self, hass, config, config_entry, discovery_data): def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the tracker.""" """Initialize the tracker."""
self._location_name = None self._location_name: str | None = None
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _setup_from_config(self, config): def _setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup the entity.""" """(Re)Setup the entity."""
self._value_template = MqttValueTemplate( self._value_template = MqttValueTemplate(
self._config.get(CONF_VALUE_TEMPLATE), entity=self config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def message_received(msg): def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
payload = self._value_template(msg.payload) payload: ReceivePayloadType = self._value_template(msg.payload)
if payload == self._config[CONF_PAYLOAD_HOME]: if payload == self._config[CONF_PAYLOAD_HOME]:
self._location_name = STATE_HOME self._location_name = STATE_HOME
elif payload == self._config[CONF_PAYLOAD_NOT_HOME]: elif payload == self._config[CONF_PAYLOAD_NOT_HOME]:
self._location_name = STATE_NOT_HOME self._location_name = STATE_NOT_HOME
else: else:
assert isinstance(msg.payload, str)
self._location_name = msg.payload self._location_name = msg.payload
get_mqtt_data(self.hass).state_write_requests.write_state_request(self) get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
@ -128,46 +136,50 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
}, },
) )
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)
@property @property
def latitude(self): def latitude(self) -> float | None:
"""Return latitude if provided in extra_state_attributes or None.""" """Return latitude if provided in extra_state_attributes or None."""
if ( if (
self.extra_state_attributes is not None self.extra_state_attributes is not None
and ATTR_LATITUDE in self.extra_state_attributes and ATTR_LATITUDE in self.extra_state_attributes
): ):
return self.extra_state_attributes[ATTR_LATITUDE] latitude: float = self.extra_state_attributes[ATTR_LATITUDE]
return latitude
return None return None
@property @property
def location_accuracy(self): def location_accuracy(self) -> int:
"""Return location accuracy if provided in extra_state_attributes or None.""" """Return location accuracy if provided in extra_state_attributes or None."""
if ( if (
self.extra_state_attributes is not None self.extra_state_attributes is not None
and ATTR_GPS_ACCURACY in self.extra_state_attributes and ATTR_GPS_ACCURACY in self.extra_state_attributes
): ):
return self.extra_state_attributes[ATTR_GPS_ACCURACY] accuracy: int = self.extra_state_attributes[ATTR_GPS_ACCURACY]
return None return accuracy
return 0
@property @property
def longitude(self): def longitude(self) -> float | None:
"""Return longitude if provided in extra_state_attributes or None.""" """Return longitude if provided in extra_state_attributes or None."""
if ( if (
self.extra_state_attributes is not None self.extra_state_attributes is not None
and ATTR_LONGITUDE in self.extra_state_attributes and ATTR_LONGITUDE in self.extra_state_attributes
): ):
return self.extra_state_attributes[ATTR_LONGITUDE] longitude: float = self.extra_state_attributes[ATTR_LONGITUDE]
return longitude
return None return None
@property @property
def location_name(self): def location_name(self) -> str | None:
"""Return a location name for the current location of the device.""" """Return a location name for the current location of the device."""
return self._location_name return self._location_name
@property @property
def source_type(self) -> SourceType | str: def source_type(self) -> SourceType | str:
"""Return the source type, eg gps or router, of the device.""" """Return the source type, eg gps or router, of the device."""
return self._config[CONF_SOURCE_TYPE] source_type: SourceType | str = self._config[CONF_SOURCE_TYPE]
return source_type

View File

@ -1,14 +1,19 @@
"""Support for tracking MQTT enabled devices defined in YAML.""" """Support for tracking MQTT enabled devices defined in YAML."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable, Callable from collections.abc import Callable, Coroutine
import dataclasses import dataclasses
import logging import logging
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.components.device_tracker import PLATFORM_SCHEMA, SOURCE_TYPES from homeassistant.components.device_tracker import (
PLATFORM_SCHEMA,
SOURCE_TYPES,
AsyncSeeCallback,
SourceType,
)
from homeassistant.const import CONF_DEVICES, STATE_HOME, STATE_NOT_HOME from homeassistant.const import CONF_DEVICES, STATE_HOME, STATE_NOT_HOME
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
@ -18,6 +23,7 @@ from ... import mqtt
from ..client import async_subscribe from ..client import async_subscribe
from ..config import SCHEMA_BASE from ..config import SCHEMA_BASE
from ..const import CONF_QOS, MQTT_DATA_DEVICE_TRACKER_LEGACY from ..const import CONF_QOS, MQTT_DATA_DEVICE_TRACKER_LEGACY
from ..models import ReceiveMessage
from ..util import mqtt_config_entry_enabled, valid_subscribe_topic from ..util import mqtt_config_entry_enabled, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -40,22 +46,22 @@ PLATFORM_SCHEMA_YAML = PLATFORM_SCHEMA.extend(SCHEMA_BASE).extend(
class MQTTLegacyDeviceTrackerData: class MQTTLegacyDeviceTrackerData:
"""Class to hold device tracker data.""" """Class to hold device tracker data."""
async_see: Callable[..., Awaitable[None]] async_see: Callable[..., Coroutine[Any, Any, None]]
config: ConfigType config: ConfigType
async def async_setup_scanner_from_yaml( async def async_setup_scanner_from_yaml(
hass: HomeAssistant, hass: HomeAssistant,
config: ConfigType, config: ConfigType,
async_see: Callable[..., Awaitable[None]], async_see: AsyncSeeCallback,
discovery_info: DiscoveryInfoType | None = None, discovery_info: DiscoveryInfoType | None = None,
) -> bool: ) -> bool:
"""Set up the MQTT tracker.""" """Set up the MQTT tracker."""
devices = config[CONF_DEVICES] devices: dict[str, str] = config[CONF_DEVICES]
qos = config[CONF_QOS] qos: int = config[CONF_QOS]
payload_home = config[CONF_PAYLOAD_HOME] payload_home: str = config[CONF_PAYLOAD_HOME]
payload_not_home = config[CONF_PAYLOAD_NOT_HOME] payload_not_home: str = config[CONF_PAYLOAD_NOT_HOME]
source_type = config.get(CONF_SOURCE_TYPE) source_type: SourceType | str | None = config.get(CONF_SOURCE_TYPE)
config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
subscriptions: list[Callable] = [] subscriptions: list[Callable] = []
@ -78,16 +84,19 @@ async def async_setup_scanner_from_yaml(
for dev_id, topic in devices.items(): for dev_id, topic in devices.items():
@callback @callback
def async_message_received(msg, dev_id=dev_id): def async_message_received(msg: ReceiveMessage, dev_id: str = dev_id) -> None:
"""Handle received MQTT message.""" """Handle received MQTT message."""
if msg.payload == payload_home: if msg.payload == payload_home:
location_name = STATE_HOME location_name = STATE_HOME
elif msg.payload == payload_not_home: elif msg.payload == payload_not_home:
location_name = STATE_NOT_HOME location_name = STATE_NOT_HOME
else: else:
location_name = msg.payload location_name = str(msg.payload)
see_args = {"dev_id": dev_id, "location_name": location_name} see_args: dict[str, Any] = {
"dev_id": dev_id,
"location_name": location_name,
}
if source_type: if source_type:
see_args["source_type"] = source_type see_args["source_type"] = source_type

View File

@ -410,12 +410,12 @@ async def test_setting_device_tracker_location_via_lat_lon_message(
async_fire_mqtt_message( async_fire_mqtt_message(
hass, hass,
"attributes-topic", "attributes-topic",
'{"latitude":50.1,"longitude": -2.1, "gps_accuracy":1.5}', '{"latitude":50.1,"longitude": -2.1}',
) )
state = hass.states.get("device_tracker.test") state = hass.states.get("device_tracker.test")
assert state.attributes["latitude"] == 50.1 assert state.attributes["latitude"] == 50.1
assert state.attributes["longitude"] == -2.1 assert state.attributes["longitude"] == -2.1
assert state.attributes["gps_accuracy"] == 1.5 assert state.attributes["gps_accuracy"] == 0
assert state.state == STATE_NOT_HOME assert state.state == STATE_NOT_HOME
async_fire_mqtt_message(hass, "attributes-topic", '{"longitude": -117.22743}') async_fire_mqtt_message(hass, "attributes-topic", '{"longitude": -117.22743}')