Enable type checks for device_tracker (#50805)

* Enable type checks for device_tracker

* Fix MQTT test
pull/50852/head
Ruslan Sayfutdinov 2021-05-19 09:36:26 +01:00 committed by GitHub
parent 4c7fcae536
commit 62386c8676
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 84 additions and 73 deletions

View File

@ -55,7 +55,7 @@ class ActiontecDeviceScanner(DeviceScanner):
self._update_info()
return [client.mac_address for client in self.last_results]
def get_device_name(self, device: str) -> str | None: # type: ignore[override]
def get_device_name(self, device: str) -> str | None:
"""Return the name of the given device or None if we don't know."""
for client in self.last_results:
if client.mac_address == device:

View File

@ -82,19 +82,19 @@ class TrackerEntity(BaseTrackerEntity):
return 0
@property
def location_name(self) -> str:
def location_name(self) -> str | None:
"""Return a location name for the current location of the device."""
return None
@property
def latitude(self) -> float:
def latitude(self) -> float | None:
"""Return latitude value of the device."""
return NotImplementedError
raise NotImplementedError
@property
def longitude(self) -> float:
def longitude(self) -> float | None:
"""Return longitude value of the device."""
return NotImplementedError
raise NotImplementedError
@property
def state(self):
@ -102,7 +102,7 @@ class TrackerEntity(BaseTrackerEntity):
if self.location_name:
return self.location_name
if self.latitude is not None:
if self.latitude is not None and self.longitude is not None:
zone_state = zone.async_active_zone(
self.hass, self.latitude, self.longitude, self.location_accuracy
)

View File

@ -309,7 +309,7 @@ async def async_create_platform_type(
def async_setup_scanner_platform(
hass: HomeAssistant,
config: ConfigType,
scanner: Any,
scanner: DeviceScanner,
async_see_device: Callable,
platform: str,
):
@ -324,7 +324,7 @@ def async_setup_scanner_platform(
# Initial scan of each mac we also tell about host name for config
seen: Any = set()
async def async_device_tracker_scan(now: dt_util.dt.datetime):
async def async_device_tracker_scan(now: dt_util.dt.datetime | None):
"""Handle interval matches."""
if update_lock.locked():
LOGGER.warning(
@ -424,21 +424,21 @@ class DeviceTracker:
def see(
self,
mac: str = None,
dev_id: str = None,
host_name: str = None,
location_name: str = None,
gps: GPSType = None,
gps_accuracy: int = None,
battery: int = None,
attributes: dict = None,
mac: str | None = None,
dev_id: str | None = None,
host_name: str | None = None,
location_name: str | None = None,
gps: GPSType | None = None,
gps_accuracy: int | None = None,
battery: int | None = None,
attributes: dict | None = None,
source_type: str = SOURCE_TYPE_GPS,
picture: str = None,
icon: str = None,
consider_home: timedelta = None,
picture: str | None = None,
icon: str | None = None,
consider_home: timedelta | None = None,
):
"""Notify the device tracker that you see a device."""
self.hass.add_job(
self.hass.create_task(
self.async_see(
mac,
dev_id,
@ -457,19 +457,19 @@ class DeviceTracker:
async def async_see(
self,
mac: str = None,
dev_id: str = None,
host_name: str = None,
location_name: str = None,
gps: GPSType = None,
gps_accuracy: int = None,
battery: int = None,
attributes: dict = None,
mac: str | None = None,
dev_id: str | None = None,
host_name: str | None = None,
location_name: str | None = None,
gps: GPSType | None = None,
gps_accuracy: int | None = None,
battery: int | None = None,
attributes: dict | None = None,
source_type: str = SOURCE_TYPE_GPS,
picture: str = None,
icon: str = None,
consider_home: timedelta = None,
):
picture: str | None = None,
icon: str | None = None,
consider_home: timedelta | None = None,
) -> None:
"""Notify the device tracker that you see a device.
This method is a coroutine.
@ -480,13 +480,13 @@ class DeviceTracker:
if mac is not None:
mac = str(mac).upper()
device = self.mac_to_dev.get(mac)
if not device:
if device is None:
dev_id = util.slugify(host_name or "") or util.slugify(mac)
else:
dev_id = cv.slug(str(dev_id).lower())
device = self.devices.get(dev_id)
if device:
if device is not None:
await device.async_seen(
host_name,
location_name,
@ -501,6 +501,9 @@ class DeviceTracker:
device.async_write_ha_state()
return
# If it's None then device is not None and we can't get here.
assert dev_id is not None
# Guard from calling see on entity registry entities.
entity_id = f"{DOMAIN}.{dev_id}"
if registry.async_is_registered(entity_id):
@ -598,15 +601,13 @@ class DeviceTracker:
class Device(RestoreEntity):
"""Base class for a tracked device."""
host_name: str = None
location_name: str = None
gps: GPSType = None
host_name: str | None = None
location_name: str | None = None
gps: GPSType | None = None
gps_accuracy: int = 0
last_seen: dt_util.dt.datetime = None
consider_home: dt_util.dt.timedelta = None
battery: int = None
attributes: dict = None
icon: str = None
last_seen: dt_util.dt.datetime | None = None
battery: int | None = None
attributes: dict | None = None
# Track if the last update of this device was HOME.
last_update_home = False
@ -618,11 +619,11 @@ class Device(RestoreEntity):
consider_home: timedelta,
track: bool,
dev_id: str,
mac: str,
name: str = None,
picture: str = None,
gravatar: str = None,
icon: str = None,
mac: str | None,
name: str | None = None,
picture: str | None = None,
gravatar: str | None = None,
icon: str | None = None,
) -> None:
"""Initialize a device."""
self.hass = hass
@ -648,11 +649,11 @@ class Device(RestoreEntity):
else:
self.config_picture = picture
self.icon = icon
self._icon = icon
self.source_type = None
self.source_type: str | None = None
self._attributes = {}
self._attributes: dict[str, Any] = {}
@property
def name(self):
@ -686,21 +687,26 @@ class Device(RestoreEntity):
return attributes
@property
def extra_state_attributes(self):
def extra_state_attributes(self) -> dict[str, Any]:
"""Return device state attributes."""
return self._attributes
@property
def icon(self) -> str | None:
"""Return device icon."""
return self._icon
async def async_seen(
self,
host_name: str = None,
location_name: str = None,
gps: GPSType = None,
gps_accuracy=0,
battery: int = None,
attributes: dict = None,
host_name: str | None = None,
location_name: str | None = None,
gps: GPSType | None = None,
gps_accuracy: int | None = None,
battery: int | None = None,
attributes: dict[str, Any] | None = None,
source_type: str = SOURCE_TYPE_GPS,
consider_home: timedelta = None,
):
consider_home: timedelta | None = None,
) -> None:
"""Mark the device as seen."""
self.source_type = source_type
self.last_seen = dt_util.utcnow()
@ -708,9 +714,9 @@ class Device(RestoreEntity):
self.location_name = location_name
self.consider_home = consider_home or self.consider_home
if battery:
if battery is not None:
self.battery = battery
if attributes:
if attributes is not None:
self._attributes.update(attributes)
self.gps = None
@ -726,7 +732,7 @@ class Device(RestoreEntity):
await self.async_update()
def stale(self, now: dt_util.dt.datetime = None):
def stale(self, now: dt_util.dt.datetime | None = None) -> bool:
"""Return if device state is stale.
Async friendly.
@ -795,7 +801,7 @@ class Device(RestoreEntity):
class DeviceScanner:
"""Device scanner object."""
hass: HomeAssistant = None
hass: HomeAssistant | None = None
def scan_devices(self) -> list[str]:
"""Scan for devices."""
@ -803,14 +809,20 @@ class DeviceScanner:
async def async_scan_devices(self) -> Any:
"""Scan for devices."""
assert (
self.hass is not None
), "hass should be set by async_setup_scanner_platform"
return await self.hass.async_add_executor_job(self.scan_devices)
def get_device_name(self, device: str) -> str:
def get_device_name(self, device: str) -> str | None:
"""Get the name of a device."""
raise NotImplementedError()
async def async_get_device_name(self, device: str) -> Any:
async def async_get_device_name(self, device: str) -> str | None:
"""Get the name of a device."""
assert (
self.hass is not None
), "hass should be set by async_setup_scanner_platform"
return await self.hass.async_add_executor_job(self.get_device_name, device)
def get_extra_attributes(self, device: str) -> dict:
@ -819,6 +831,9 @@ class DeviceScanner:
async def async_get_extra_attributes(self, device: str) -> Any:
"""Get the extra attributes of a device."""
assert (
self.hass is not None
), "hass should be set by async_setup_scanner_platform"
return await self.hass.async_add_executor_job(self.get_extra_attributes, device)
@ -868,7 +883,7 @@ async def async_load_config(path: str, hass: HomeAssistant, consider_home: timed
def update_config(path: str, dev_id: str, device: Device):
"""Add device to YAML configuration file."""
with open(path, "a") as out:
device = {
device_config = {
device.dev_id: {
ATTR_NAME: device.name,
ATTR_MAC: device.mac,
@ -878,7 +893,7 @@ def update_config(path: str, dev_id: str, device: Device):
}
}
out.write("\n")
out.write(dump(device))
out.write(dump(device_config))
def get_gravatar_for_email(email: str):

View File

@ -816,9 +816,6 @@ ignore_errors = true
[mypy-homeassistant.components.denonavr.*]
ignore_errors = true
[mypy-homeassistant.components.device_tracker.*]
ignore_errors = true
[mypy-homeassistant.components.devolo_home_control.*]
ignore_errors = true

View File

@ -44,7 +44,6 @@ IGNORED_MODULES: Final[list[str]] = [
"homeassistant.components.deconz.*",
"homeassistant.components.demo.*",
"homeassistant.components.denonavr.*",
"homeassistant.components.device_tracker.*",
"homeassistant.components.devolo_home_control.*",
"homeassistant.components.dhcp.*",
"homeassistant.components.directv.*",

View File

@ -359,4 +359,4 @@ async def test_setting_device_tracker_location_via_lat_lon_message(
async_fire_mqtt_message(hass, "attributes-topic", '{"latitude":32.87336}')
state = hass.states.get("device_tracker.test")
assert state.attributes["latitude"] == 32.87336
assert state.state == STATE_NOT_HOME
assert state.state == STATE_UNKNOWN