Add restore state to OwnTracks device tracker (#24256)

* Add restore state to OwnTracks device tracker

* Lint

* Also store entity devices

* Update test_device_tracker.py
pull/24268/head
Paulus Schoutsen 2019-06-02 13:57:21 -07:00 committed by GitHub
parent 05454b76a6
commit ca20b0cf17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 124 additions and 56 deletions

View File

@ -2,10 +2,19 @@
import logging import logging
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.components.device_tracker.const import ENTITY_ID_FORMAT from homeassistant.const import (
ATTR_GPS_ACCURACY,
ATTR_LATITUDE,
ATTR_LONGITUDE,
ATTR_BATTERY_LEVEL,
)
from homeassistant.components.device_tracker.const import (
ENTITY_ID_FORMAT, ATTR_SOURCE_TYPE)
from homeassistant.components.device_tracker.config_entry import ( from homeassistant.components.device_tracker.config_entry import (
DeviceTrackerEntity DeviceTrackerEntity
) )
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers import device_registry
from . import DOMAIN as OT_DOMAIN from . import DOMAIN as OT_DOMAIN
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -14,53 +23,52 @@ _LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass, entry, async_add_entities): async def async_setup_entry(hass, entry, async_add_entities):
"""Set up OwnTracks based off an entry.""" """Set up OwnTracks based off an entry."""
@callback @callback
def _receive_data(dev_id, host_name, gps, attributes, gps_accuracy=None, def _receive_data(dev_id, **data):
battery=None, source_type=None, location_name=None):
"""Receive set location.""" """Receive set location."""
device = hass.data[OT_DOMAIN]['devices'].get(dev_id) entity = hass.data[OT_DOMAIN]['devices'].get(dev_id)
if device is not None: if entity is not None:
device.update_data( entity.update_data(data)
host_name=host_name,
gps=gps,
attributes=attributes,
gps_accuracy=gps_accuracy,
battery=battery,
source_type=source_type,
location_name=location_name,
)
return return
device = hass.data[OT_DOMAIN]['devices'][dev_id] = OwnTracksEntity( entity = hass.data[OT_DOMAIN]['devices'][dev_id] = OwnTracksEntity(
dev_id=dev_id, dev_id, data
host_name=host_name,
gps=gps,
attributes=attributes,
gps_accuracy=gps_accuracy,
battery=battery,
source_type=source_type,
location_name=location_name,
) )
async_add_entities([device]) async_add_entities([entity])
hass.data[OT_DOMAIN]['context'].async_see = _receive_data hass.data[OT_DOMAIN]['context'].async_see = _receive_data
# Restore previously loaded devices
dev_reg = await device_registry.async_get_registry(hass)
dev_ids = {
identifier[1]
for device in dev_reg.devices.values()
for identifier in device.identifiers
if identifier[0] == OT_DOMAIN
}
if not dev_ids:
return True
entities = []
for dev_id in dev_ids:
entity = hass.data[OT_DOMAIN]['devices'][dev_id] = OwnTracksEntity(
dev_id
)
entities.append(entity)
async_add_entities(entities)
return True return True
class OwnTracksEntity(DeviceTrackerEntity): class OwnTracksEntity(DeviceTrackerEntity, RestoreEntity):
"""Represent a tracked device.""" """Represent a tracked device."""
def __init__(self, dev_id, host_name, gps, attributes, gps_accuracy, def __init__(self, dev_id, data=None):
battery, source_type, location_name):
"""Set up OwnTracks entity.""" """Set up OwnTracks entity."""
self._dev_id = dev_id self._dev_id = dev_id
self._host_name = host_name self._data = data or {}
self._gps = gps
self._gps_accuracy = gps_accuracy
self._location_name = location_name
self._attributes = attributes
self._battery = battery
self._source_type = source_type
self.entity_id = ENTITY_ID_FORMAT.format(dev_id) self.entity_id = ENTITY_ID_FORMAT.format(dev_id)
@property @property
@ -71,43 +79,45 @@ class OwnTracksEntity(DeviceTrackerEntity):
@property @property
def battery_level(self): def battery_level(self):
"""Return the battery level of the device.""" """Return the battery level of the device."""
return self._battery return self._data.get('battery')
@property @property
def device_state_attributes(self): def device_state_attributes(self):
"""Return device specific attributes.""" """Return device specific attributes."""
return self._attributes return self._data.get('attributes')
@property @property
def location_accuracy(self): def location_accuracy(self):
"""Return the gps accuracy of the device.""" """Return the gps accuracy of the device."""
return self._gps_accuracy return self._data.get('gps_accuracy')
@property @property
def latitude(self): def latitude(self):
"""Return latitude value of the device.""" """Return latitude value of the device."""
if self._gps is not None: # Check with "get" instead of "in" because value can be None
return self._gps[0] if self._data.get('gps'):
return self._data['gps'][0]
return None return None
@property @property
def longitude(self): def longitude(self):
"""Return longitude value of the device.""" """Return longitude value of the device."""
if self._gps is not None: # Check with "get" instead of "in" because value can be None
return self._gps[1] if self._data.get('gps'):
return self._data['gps'][1]
return None return None
@property @property
def location_name(self): def location_name(self):
"""Return a location name for the current location of the device.""" """Return a location name for the current location of the device."""
return self._location_name return self._data.get('location_name')
@property @property
def name(self): def name(self):
"""Return the name of the device.""" """Return the name of the device."""
return self._host_name return self._data.get('host_name')
@property @property
def should_poll(self): def should_poll(self):
@ -117,26 +127,40 @@ class OwnTracksEntity(DeviceTrackerEntity):
@property @property
def source_type(self): def source_type(self):
"""Return the source type, eg gps or router, of the device.""" """Return the source type, eg gps or router, of the device."""
return self._source_type return self._data.get('source_type')
@property @property
def device_info(self): def device_info(self):
"""Return the device info.""" """Return the device info."""
return { return {
'name': self._host_name, 'name': self.name,
'identifiers': {(OT_DOMAIN, self._dev_id)}, 'identifiers': {(OT_DOMAIN, self._dev_id)},
} }
@callback async def async_added_to_hass(self):
def update_data(self, host_name, gps, attributes, gps_accuracy, """Call when entity about to be added to Home Assistant."""
battery, source_type, location_name): await super().async_added_to_hass()
"""Mark the device as seen."""
self._host_name = host_name
self._gps = gps
self._gps_accuracy = gps_accuracy
self._location_name = location_name
self._attributes = attributes
self._battery = battery
self._source_type = source_type
# Don't restore if we got set up with data.
if self._data:
return
state = await self.async_get_last_state()
if state is None:
return
attr = state.attributes
self._data = {
'host_name': state.name,
'gps': (attr[ATTR_LATITUDE], attr[ATTR_LONGITUDE]),
'gps_accuracy': attr[ATTR_GPS_ACCURACY],
'battery': attr[ATTR_BATTERY_LEVEL],
'source_type': attr[ATTR_SOURCE_TYPE],
}
@callback
def update_data(self, data):
"""Mark the device as seen."""
self._data = data
self.async_write_ha_state() self.async_write_ha_state()

View File

@ -1491,3 +1491,47 @@ async def test_region_mapping(hass, setup_comp):
await send_message(hass, EVENT_TOPIC, message) await send_message(hass, EVENT_TOPIC, message)
assert_location_state(hass, 'inner') assert_location_state(hass, 'inner')
async def test_restore_state(hass, hass_client):
"""Test that we can restore state."""
entry = MockConfigEntry(domain='owntracks', data={
'webhook_id': 'owntracks_test',
'secret': 'abcd',
})
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
client = await hass_client()
resp = await client.post(
'/api/webhook/owntracks_test',
json=LOCATION_MESSAGE,
headers={
'X-Limit-u': 'Paulus',
'X-Limit-d': 'Pixel',
}
)
assert resp.status == 200
await hass.async_block_till_done()
state_1 = hass.states.get('device_tracker.paulus_pixel')
assert state_1 is not None
await hass.config_entries.async_reload(entry.entry_id)
await hass.async_block_till_done()
state_2 = hass.states.get('device_tracker.paulus_pixel')
assert state_2 is not None
assert state_1 is not state_2
assert state_1.state == state_2.state
assert state_1.name == state_2.name
assert state_1.attributes['latitude'] == state_2.attributes['latitude']
assert state_1.attributes['longitude'] == state_2.attributes['longitude']
assert state_1.attributes['battery_level'] == \
state_2.attributes['battery_level']
assert state_1.attributes['source_type'] == \
state_2.attributes['source_type']