From ca20b0cf17aef4e080500e13287a021598c8d4dd Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 2 Jun 2019 13:57:21 -0700 Subject: [PATCH] Add restore state to OwnTracks device tracker (#24256) * Add restore state to OwnTracks device tracker * Lint * Also store entity devices * Update test_device_tracker.py --- .../components/owntracks/device_tracker.py | 136 ++++++++++-------- .../owntracks/test_device_tracker.py | 44 ++++++ 2 files changed, 124 insertions(+), 56 deletions(-) diff --git a/homeassistant/components/owntracks/device_tracker.py b/homeassistant/components/owntracks/device_tracker.py index fb9fedf26fa..d74fea43c29 100644 --- a/homeassistant/components/owntracks/device_tracker.py +++ b/homeassistant/components/owntracks/device_tracker.py @@ -2,10 +2,19 @@ import logging from homeassistant.core import callback -from homeassistant.components.device_tracker.const import ENTITY_ID_FORMAT +from homeassistant.const import ( + ATTR_GPS_ACCURACY, + ATTR_LATITUDE, + ATTR_LONGITUDE, + ATTR_BATTERY_LEVEL, +) +from homeassistant.components.device_tracker.const import ( + ENTITY_ID_FORMAT, ATTR_SOURCE_TYPE) from homeassistant.components.device_tracker.config_entry import ( DeviceTrackerEntity ) +from homeassistant.helpers.restore_state import RestoreEntity +from homeassistant.helpers import device_registry from . import DOMAIN as OT_DOMAIN _LOGGER = logging.getLogger(__name__) @@ -14,53 +23,52 @@ _LOGGER = logging.getLogger(__name__) async def async_setup_entry(hass, entry, async_add_entities): """Set up OwnTracks based off an entry.""" @callback - def _receive_data(dev_id, host_name, gps, attributes, gps_accuracy=None, - battery=None, source_type=None, location_name=None): + def _receive_data(dev_id, **data): """Receive set location.""" - device = hass.data[OT_DOMAIN]['devices'].get(dev_id) + entity = hass.data[OT_DOMAIN]['devices'].get(dev_id) - if device is not None: - device.update_data( - host_name=host_name, - gps=gps, - attributes=attributes, - gps_accuracy=gps_accuracy, - battery=battery, - source_type=source_type, - location_name=location_name, - ) + if entity is not None: + entity.update_data(data) return - device = hass.data[OT_DOMAIN]['devices'][dev_id] = OwnTracksEntity( - dev_id=dev_id, - host_name=host_name, - gps=gps, - attributes=attributes, - gps_accuracy=gps_accuracy, - battery=battery, - source_type=source_type, - location_name=location_name, + entity = hass.data[OT_DOMAIN]['devices'][dev_id] = OwnTracksEntity( + dev_id, data ) - async_add_entities([device]) + async_add_entities([entity]) hass.data[OT_DOMAIN]['context'].async_see = _receive_data + + # Restore previously loaded devices + dev_reg = await device_registry.async_get_registry(hass) + dev_ids = { + identifier[1] + for device in dev_reg.devices.values() + for identifier in device.identifiers + if identifier[0] == OT_DOMAIN + } + + if not dev_ids: + return True + + entities = [] + for dev_id in dev_ids: + entity = hass.data[OT_DOMAIN]['devices'][dev_id] = OwnTracksEntity( + dev_id + ) + entities.append(entity) + + async_add_entities(entities) + return True -class OwnTracksEntity(DeviceTrackerEntity): +class OwnTracksEntity(DeviceTrackerEntity, RestoreEntity): """Represent a tracked device.""" - def __init__(self, dev_id, host_name, gps, attributes, gps_accuracy, - battery, source_type, location_name): + def __init__(self, dev_id, data=None): """Set up OwnTracks entity.""" self._dev_id = dev_id - self._host_name = host_name - self._gps = gps - self._gps_accuracy = gps_accuracy - self._location_name = location_name - self._attributes = attributes - self._battery = battery - self._source_type = source_type + self._data = data or {} self.entity_id = ENTITY_ID_FORMAT.format(dev_id) @property @@ -71,43 +79,45 @@ class OwnTracksEntity(DeviceTrackerEntity): @property def battery_level(self): """Return the battery level of the device.""" - return self._battery + return self._data.get('battery') @property def device_state_attributes(self): """Return device specific attributes.""" - return self._attributes + return self._data.get('attributes') @property def location_accuracy(self): """Return the gps accuracy of the device.""" - return self._gps_accuracy + return self._data.get('gps_accuracy') @property def latitude(self): """Return latitude value of the device.""" - if self._gps is not None: - return self._gps[0] + # Check with "get" instead of "in" because value can be None + if self._data.get('gps'): + return self._data['gps'][0] return None @property def longitude(self): """Return longitude value of the device.""" - if self._gps is not None: - return self._gps[1] + # Check with "get" instead of "in" because value can be None + if self._data.get('gps'): + return self._data['gps'][1] return None @property def location_name(self): """Return a location name for the current location of the device.""" - return self._location_name + return self._data.get('location_name') @property def name(self): """Return the name of the device.""" - return self._host_name + return self._data.get('host_name') @property def should_poll(self): @@ -117,26 +127,40 @@ class OwnTracksEntity(DeviceTrackerEntity): @property def source_type(self): """Return the source type, eg gps or router, of the device.""" - return self._source_type + return self._data.get('source_type') @property def device_info(self): """Return the device info.""" return { - 'name': self._host_name, + 'name': self.name, 'identifiers': {(OT_DOMAIN, self._dev_id)}, } - @callback - def update_data(self, host_name, gps, attributes, gps_accuracy, - battery, source_type, location_name): - """Mark the device as seen.""" - self._host_name = host_name - self._gps = gps - self._gps_accuracy = gps_accuracy - self._location_name = location_name - self._attributes = attributes - self._battery = battery - self._source_type = source_type + async def async_added_to_hass(self): + """Call when entity about to be added to Home Assistant.""" + await super().async_added_to_hass() + # Don't restore if we got set up with data. + if self._data: + return + + state = await self.async_get_last_state() + + if state is None: + return + + attr = state.attributes + self._data = { + 'host_name': state.name, + 'gps': (attr[ATTR_LATITUDE], attr[ATTR_LONGITUDE]), + 'gps_accuracy': attr[ATTR_GPS_ACCURACY], + 'battery': attr[ATTR_BATTERY_LEVEL], + 'source_type': attr[ATTR_SOURCE_TYPE], + } + + @callback + def update_data(self, data): + """Mark the device as seen.""" + self._data = data self.async_write_ha_state() diff --git a/tests/components/owntracks/test_device_tracker.py b/tests/components/owntracks/test_device_tracker.py index b81f434a2c1..7d8d48de586 100644 --- a/tests/components/owntracks/test_device_tracker.py +++ b/tests/components/owntracks/test_device_tracker.py @@ -1491,3 +1491,47 @@ async def test_region_mapping(hass, setup_comp): await send_message(hass, EVENT_TOPIC, message) assert_location_state(hass, 'inner') + + +async def test_restore_state(hass, hass_client): + """Test that we can restore state.""" + entry = MockConfigEntry(domain='owntracks', data={ + 'webhook_id': 'owntracks_test', + 'secret': 'abcd', + }) + entry.add_to_hass(hass) + + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + client = await hass_client() + resp = await client.post( + '/api/webhook/owntracks_test', + json=LOCATION_MESSAGE, + headers={ + 'X-Limit-u': 'Paulus', + 'X-Limit-d': 'Pixel', + } + ) + assert resp.status == 200 + await hass.async_block_till_done() + + state_1 = hass.states.get('device_tracker.paulus_pixel') + assert state_1 is not None + + await hass.config_entries.async_reload(entry.entry_id) + await hass.async_block_till_done() + + state_2 = hass.states.get('device_tracker.paulus_pixel') + assert state_2 is not None + + assert state_1 is not state_2 + + assert state_1.state == state_2.state + assert state_1.name == state_2.name + assert state_1.attributes['latitude'] == state_2.attributes['latitude'] + assert state_1.attributes['longitude'] == state_2.attributes['longitude'] + assert state_1.attributes['battery_level'] == \ + state_2.attributes['battery_level'] + assert state_1.attributes['source_type'] == \ + state_2.attributes['source_type']