From f8651d9faa87cec74985de0751d829175489c557 Mon Sep 17 00:00:00 2001 From: Rami Mosleh Date: Tue, 7 Jul 2020 01:18:56 +0300 Subject: [PATCH] Stop Speedtest sensors update on startup if manual option is enabled (#37403) Co-authored-by: Paulus Schoutsen --- .../components/speedtestdotnet/__init__.py | 28 ++++++++---- .../components/speedtestdotnet/config_flow.py | 4 +- .../components/speedtestdotnet/sensor.py | 44 +++++++++++++------ 3 files changed, 53 insertions(+), 23 deletions(-) diff --git a/homeassistant/components/speedtestdotnet/__init__.py b/homeassistant/components/speedtestdotnet/__init__.py index 3cad15a0967..1b50516f340 100644 --- a/homeassistant/components/speedtestdotnet/__init__.py +++ b/homeassistant/components/speedtestdotnet/__init__.py @@ -70,9 +70,10 @@ async def async_setup_entry(hass, config_entry): coordinator = SpeedTestDataCoordinator(hass, config_entry) await coordinator.async_setup() - await coordinator.async_refresh() - if not coordinator.last_update_success: - raise ConfigEntryNotReady + if not config_entry.options[CONF_MANUAL]: + await coordinator.async_refresh() + if not coordinator.last_update_success: + raise ConfigEntryNotReady hass.data[DOMAIN] = coordinator @@ -115,9 +116,12 @@ class SpeedTestDataCoordinator(DataUpdateCoordinator): ), ) - def update_data(self): - """Get the latest data from speedtest.net.""" - server_list = self.api.get_servers() + def update_servers(self): + """Update list of test servers.""" + try: + server_list = self.api.get_servers() + except speedtest.ConfigRetrievalError: + return self.servers[DEFAULT_SERVER] = {} for server in sorted( @@ -125,14 +129,20 @@ class SpeedTestDataCoordinator(DataUpdateCoordinator): ): self.servers[f"{server[0]['country']} - {server[0]['sponsor']}"] = server[0] + def update_data(self): + """Get the latest data from speedtest.net.""" + self.update_servers() + + self.api.closest.clear() if self.config_entry.options.get(CONF_SERVER_ID): server_id = self.config_entry.options.get(CONF_SERVER_ID) - self.api.closest.clear() self.api.get_servers(servers=[server_id]) + + self.api.get_best_server() _LOGGER.debug( "Executing speedtest.net speed test with server_id: %s", self.api.best["id"] ) - self.api.get_best_server() + self.api.download() self.api.upload() return self.api.results.dict() @@ -170,6 +180,8 @@ class SpeedTestDataCoordinator(DataUpdateCoordinator): await self.async_set_options() + await self.hass.async_add_executor_job(self.update_servers) + self.hass.services.async_register(DOMAIN, SPEED_TEST_SERVICE, request_update) self.config_entry.add_update_listener(options_updated_listener) diff --git a/homeassistant/components/speedtestdotnet/config_flow.py b/homeassistant/components/speedtestdotnet/config_flow.py index 1d8f3cf189b..57076c2a90b 100644 --- a/homeassistant/components/speedtestdotnet/config_flow.py +++ b/homeassistant/components/speedtestdotnet/config_flow.py @@ -85,7 +85,7 @@ class SpeedTestOptionsFlowHandler(config_entries.OptionsFlow): self._servers = self.hass.data[DOMAIN].servers - server_name = DEFAULT_SERVER + server = [] if self.config_entry.options.get( CONF_SERVER_ID ) and not self.config_entry.options.get(CONF_SERVER_NAME): @@ -94,7 +94,7 @@ class SpeedTestOptionsFlowHandler(config_entries.OptionsFlow): for (key, value) in self._servers.items() if value.get("id") == self.config_entry.options[CONF_SERVER_ID] ] - server_name = server[0] if server else "" + server_name = server[0] if server else DEFAULT_SERVER options = { vol.Optional( diff --git a/homeassistant/components/speedtestdotnet/sensor.py b/homeassistant/components/speedtestdotnet/sensor.py index 06868dc1437..0889d7da5b2 100644 --- a/homeassistant/components/speedtestdotnet/sensor.py +++ b/homeassistant/components/speedtestdotnet/sensor.py @@ -2,7 +2,8 @@ import logging from homeassistant.const import ATTR_ATTRIBUTION -from homeassistant.helpers.entity import Entity +from homeassistant.core import callback +from homeassistant.helpers.restore_state import RestoreEntity from .const import ( ATTR_BYTES_RECEIVED, @@ -11,6 +12,7 @@ from .const import ( ATTR_SERVER_ID, ATTR_SERVER_NAME, ATTRIBUTION, + CONF_MANUAL, DEFAULT_NAME, DOMAIN, ICON, @@ -32,7 +34,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): async_add_entities(entities) -class SpeedtestSensor(Entity): +class SpeedtestSensor(RestoreEntity): """Implementation of a speedtest.net sensor.""" def __init__(self, coordinator, sensor_type): @@ -41,6 +43,7 @@ class SpeedtestSensor(Entity): self.coordinator = coordinator self.type = sensor_type self._unit_of_measurement = SENSOR_TYPES[self.type][1] + self._state = None @property def name(self): @@ -55,14 +58,7 @@ class SpeedtestSensor(Entity): @property def state(self): """Return the state of the device.""" - state = None - if self.type == "ping": - state = self.coordinator.data["ping"] - elif self.type == "download": - state = round(self.coordinator.data["download"] / 10 ** 6, 2) - elif self.type == "upload": - state = round(self.coordinator.data["upload"] / 10 ** 6, 2) - return state + return self._state @property def unit_of_measurement(self): @@ -82,6 +78,8 @@ class SpeedtestSensor(Entity): @property def device_state_attributes(self): """Return the state attributes.""" + if not self.coordinator.data: + return None attributes = { ATTR_ATTRIBUTION: ATTRIBUTION, ATTR_SERVER_NAME: self.coordinator.data["server"]["name"], @@ -98,10 +96,30 @@ class SpeedtestSensor(Entity): async def async_added_to_hass(self): """Handle entity which will be added.""" + await super().async_added_to_hass() + if self.coordinator.config_entry.options[CONF_MANUAL]: + state = await self.async_get_last_state() + if state: + self._state = state.state - self.async_on_remove( - self.coordinator.async_add_listener(self.async_write_ha_state) - ) + @callback + def update(): + """Update state.""" + self._update_state() + self.async_write_ha_state() + + self.async_on_remove(self.coordinator.async_add_listener(update)) + self._update_state() + + def _update_state(self): + """Update sensors state.""" + if self.coordinator.data: + if self.type == "ping": + self._state = self.coordinator.data["ping"] + elif self.type == "download": + self._state = round(self.coordinator.data["download"] / 10 ** 6, 2) + elif self.type == "upload": + self._state = round(self.coordinator.data["upload"] / 10 ** 6, 2) async def async_update(self): """Request coordinator to update data."""