From 4d4fd19f876c9c9bc9729ffc82eb7906c850c655 Mon Sep 17 00:00:00 2001 From: Robert Svensson Date: Sun, 2 Jun 2019 18:24:13 +0200 Subject: [PATCH] Replace pyunifi with aiounifi in UniFi device tracker (#24149) * Replace pyunifi with aiounifi * Fix tests * Add sslcontext * Fix tests * Fix import order --- homeassistant/components/unifi/config_flow.py | 1 + homeassistant/components/unifi/controller.py | 7 +- .../components/unifi/device_tracker.py | 61 +++++++---- homeassistant/components/unifi/manifest.json | 3 +- requirements_all.txt | 5 +- requirements_test_all.txt | 5 +- tests/components/unifi/test_device_tracker.py | 102 +++++++++--------- tests/components/unifi/test_init.py | 6 +- 8 files changed, 103 insertions(+), 87 deletions(-) diff --git a/homeassistant/components/unifi/config_flow.py b/homeassistant/components/unifi/config_flow.py index b784aaa705a..95af8376773 100644 --- a/homeassistant/components/unifi/config_flow.py +++ b/homeassistant/components/unifi/config_flow.py @@ -84,6 +84,7 @@ class UnifiFlowHandler(config_entries.ConfigFlow): try: desc = user_input.get(CONF_SITE_ID, self.desc) + print(self.sites) for site in self.sites.values(): if desc == site['desc']: if site['role'] != 'admin': diff --git a/homeassistant/components/unifi/controller.py b/homeassistant/components/unifi/controller.py index 2b9aa89fef2..5105e33f1d6 100644 --- a/homeassistant/components/unifi/controller.py +++ b/homeassistant/components/unifi/controller.py @@ -1,5 +1,6 @@ """UniFi Controller abstraction.""" import asyncio +import ssl import async_timeout from aiohttp import CookieJar @@ -81,15 +82,19 @@ async def get_controller( """Create a controller object and verify authentication.""" import aiounifi + sslcontext = None + if verify_ssl: session = aiohttp_client.async_get_clientsession(hass) + if isinstance(verify_ssl, str): + sslcontext = ssl.create_default_context(cafile=verify_ssl) else: session = aiohttp_client.async_create_clientsession( hass, verify_ssl=verify_ssl, cookie_jar=CookieJar(unsafe=True)) controller = aiounifi.Controller( host, username=username, password=password, port=port, site=site, - websession=session + websession=session, sslcontext=sslcontext ) try: diff --git a/homeassistant/components/unifi/device_tracker.py b/homeassistant/components/unifi/device_tracker.py index 8bf384eef14..30754273254 100644 --- a/homeassistant/components/unifi/device_tracker.py +++ b/homeassistant/components/unifi/device_tracker.py @@ -1,8 +1,13 @@ """Support for Unifi WAP controllers.""" +import asyncio import logging from datetime import timedelta import voluptuous as vol +import async_timeout + +import aiounifi + import homeassistant.helpers.config_validation as cv from homeassistant.components.device_tracker import ( DOMAIN, PLATFORM_SCHEMA, DeviceScanner) @@ -10,6 +15,9 @@ from homeassistant.const import CONF_HOST, CONF_USERNAME, CONF_PASSWORD from homeassistant.const import CONF_VERIFY_SSL, CONF_MONITORED_CONDITIONS import homeassistant.util.dt as dt_util +from .controller import get_controller +from .errors import AuthenticationRequired, CannotConnect + _LOGGER = logging.getLogger(__name__) CONF_PORT = 'port' CONF_SITE_ID = 'site_id' @@ -54,10 +62,8 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ }) -def get_scanner(hass, config): +async def async_get_scanner(hass, config): """Set up the Unifi device_tracker.""" - from pyunifi.controller import Controller, APIError - host = config[DOMAIN].get(CONF_HOST) username = config[DOMAIN].get(CONF_USERNAME) password = config[DOMAIN].get(CONF_PASSWORD) @@ -69,9 +75,11 @@ def get_scanner(hass, config): ssid_filter = config[DOMAIN].get(CONF_SSID_FILTER) try: - ctrl = Controller(host, username, password, port, version='v4', - site_id=site_id, ssl_verify=verify_ssl) - except APIError as ex: + controller = await get_controller( + hass, host, username, password, port, site_id, verify_ssl) + await controller.initialize() + + except (AuthenticationRequired, CannotConnect) as ex: _LOGGER.error("Failed to connect to Unifi: %s", ex) hass.components.persistent_notification.create( 'Failed to connect to Unifi. ' @@ -82,8 +90,8 @@ def get_scanner(hass, config): notification_id=NOTIFICATION_ID) return False - return UnifiScanner(ctrl, detection_time, ssid_filter, - monitored_conditions) + return UnifiScanner( + controller, detection_time, ssid_filter, monitored_conditions) class UnifiScanner(DeviceScanner): @@ -92,36 +100,45 @@ class UnifiScanner(DeviceScanner): def __init__(self, controller, detection_time: timedelta, ssid_filter, monitored_conditions) -> None: """Initialize the scanner.""" + self.controller = controller self._detection_time = detection_time - self._controller = controller self._ssid_filter = ssid_filter self._monitored_conditions = monitored_conditions - self._update() + self._clients = {} - def _update(self): + async def async_update(self): """Get the clients from the device.""" - from pyunifi.controller import APIError try: - clients = self._controller.get_clients() - except APIError as ex: - _LOGGER.error("Failed to scan clients: %s", ex) + await self.controller.clients.update() + clients = self.controller.clients.values() + + except aiounifi.LoginRequired: + try: + with async_timeout.timeout(5): + await self.controller.login() + except (asyncio.TimeoutError, aiounifi.AiounifiException): + clients = [] + + except aiounifi.AiounifiException: clients = [] # Filter clients to provided SSID list if self._ssid_filter: - clients = [client for client in clients - if 'essid' in client and - client['essid'] in self._ssid_filter] + clients = [ + client for client in clients + if client.essid in self._ssid_filter + ] self._clients = { - client['mac']: client + client.raw['mac']: client.raw for client in clients if (dt_util.utcnow() - dt_util.utc_from_timestamp(float( - client['last_seen']))) < self._detection_time} + client.last_seen))) < self._detection_time + } - def scan_devices(self): + async def async_scan_devices(self): """Scan for devices.""" - self._update() + await self.async_update() return self._clients.keys() def get_device_name(self, device): diff --git a/homeassistant/components/unifi/manifest.json b/homeassistant/components/unifi/manifest.json index 22ece5addaf..64119bae2fe 100644 --- a/homeassistant/components/unifi/manifest.json +++ b/homeassistant/components/unifi/manifest.json @@ -4,8 +4,7 @@ "config_flow": true, "documentation": "https://www.home-assistant.io/components/unifi", "requirements": [ - "aiounifi==4", - "pyunifi==2.16" + "aiounifi==6" ], "dependencies": [], "codeowners": [ diff --git a/requirements_all.txt b/requirements_all.txt index 6c4421ed2b3..a40bb6d4e19 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -163,7 +163,7 @@ aiopvapi==1.6.14 aioswitcher==2019.3.21 # homeassistant.components.unifi -aiounifi==4 +aiounifi==6 # homeassistant.components.aladdin_connect aladdin_connect==0.3 @@ -1488,9 +1488,6 @@ pytrafikverket==0.1.5.9 # homeassistant.components.ubee pyubee==0.6 -# homeassistant.components.unifi -pyunifi==2.16 - # homeassistant.components.uptimerobot pyuptimerobot==0.0.5 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 4fde162bdd9..d945ad1628e 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -61,7 +61,7 @@ aiohue==1.9.1 aioswitcher==2019.3.21 # homeassistant.components.unifi -aiounifi==4 +aiounifi==6 # homeassistant.components.ambiclimate ambiclimate==0.1.2 @@ -294,9 +294,6 @@ python_awair==0.0.4 # homeassistant.components.tradfri pytradfri[async]==6.0.1 -# homeassistant.components.unifi -pyunifi==2.16 - # homeassistant.components.html5 pywebpush==1.9.2 diff --git a/tests/components/unifi/test_device_tracker.py b/tests/components/unifi/test_device_tracker.py index 0fb0751c5b6..5bc24c6c269 100644 --- a/tests/components/unifi/test_device_tracker.py +++ b/tests/components/unifi/test_device_tracker.py @@ -1,8 +1,6 @@ """The tests for the Unifi WAP device tracker platform.""" from unittest import mock from datetime import datetime, timedelta -from pyunifi.controller import APIError - import pytest import voluptuous as vol @@ -13,13 +11,20 @@ import homeassistant.components.unifi.device_tracker as unifi from homeassistant.const import (CONF_HOST, CONF_USERNAME, CONF_PASSWORD, CONF_PLATFORM, CONF_VERIFY_SSL, CONF_MONITORED_CONDITIONS) + +from tests.common import mock_coro +from asynctest import CoroutineMock +from aiounifi.clients import Clients + DEFAULT_DETECTION_TIME = timedelta(seconds=300) @pytest.fixture def mock_ctrl(): """Mock pyunifi.""" - with mock.patch('pyunifi.controller.Controller') as mock_control: + with mock.patch('aiounifi.Controller') as mock_control: + mock_control.return_value.login.return_value = mock_coro() + mock_control.return_value.initialize.return_value = mock_coro() yield mock_control @@ -33,7 +38,7 @@ def mock_scanner(): @mock.patch('os.access', return_value=True) @mock.patch('os.path.isfile', mock.Mock(return_value=True)) -def test_config_valid_verify_ssl(hass, mock_scanner, mock_ctrl): +async def test_config_valid_verify_ssl(hass, mock_scanner, mock_ctrl): """Test the setup with a string for ssl_verify. Representing the absolute path to a CA certificate bundle. @@ -46,12 +51,9 @@ def test_config_valid_verify_ssl(hass, mock_scanner, mock_ctrl): CONF_VERIFY_SSL: "/tmp/unifi.crt" }) } - result = unifi.get_scanner(hass, config) + result = await unifi.async_get_scanner(hass, config) assert mock_scanner.return_value == result assert mock_ctrl.call_count == 1 - assert mock_ctrl.mock_calls[0] == \ - mock.call('localhost', 'foo', 'password', 8443, - version='v4', site_id='default', ssl_verify="/tmp/unifi.crt") assert mock_scanner.call_count == 1 assert mock_scanner.call_args == mock.call(mock_ctrl.return_value, @@ -59,7 +61,7 @@ def test_config_valid_verify_ssl(hass, mock_scanner, mock_ctrl): None, None) -def test_config_minimal(hass, mock_scanner, mock_ctrl): +async def test_config_minimal(hass, mock_scanner, mock_ctrl): """Test the setup with minimal configuration.""" config = { DOMAIN: unifi.PLATFORM_SCHEMA({ @@ -68,12 +70,10 @@ def test_config_minimal(hass, mock_scanner, mock_ctrl): CONF_PASSWORD: 'password', }) } - result = unifi.get_scanner(hass, config) + + result = await unifi.async_get_scanner(hass, config) assert mock_scanner.return_value == result assert mock_ctrl.call_count == 1 - assert mock_ctrl.mock_calls[0] == \ - mock.call('localhost', 'foo', 'password', 8443, - version='v4', site_id='default', ssl_verify=True) assert mock_scanner.call_count == 1 assert mock_scanner.call_args == mock.call(mock_ctrl.return_value, @@ -81,7 +81,7 @@ def test_config_minimal(hass, mock_scanner, mock_ctrl): None, None) -def test_config_full(hass, mock_scanner, mock_ctrl): +async def test_config_full(hass, mock_scanner, mock_ctrl): """Test the setup with full configuration.""" config = { DOMAIN: unifi.PLATFORM_SCHEMA({ @@ -96,12 +96,9 @@ def test_config_full(hass, mock_scanner, mock_ctrl): 'detection_time': 300, }) } - result = unifi.get_scanner(hass, config) + result = await unifi.async_get_scanner(hass, config) assert mock_scanner.return_value == result assert mock_ctrl.call_count == 1 - assert mock_ctrl.call_args == \ - mock.call('myhost', 'foo', 'password', 123, - version='v4', site_id='abcdef01', ssl_verify=False) assert mock_scanner.call_count == 1 assert mock_scanner.call_args == mock.call( @@ -137,7 +134,7 @@ def test_config_error(): }) -def test_config_controller_failed(hass, mock_ctrl, mock_scanner): +async def test_config_controller_failed(hass, mock_ctrl, mock_scanner): """Test for controller failure.""" config = { 'device_tracker': { @@ -146,13 +143,12 @@ def test_config_controller_failed(hass, mock_ctrl, mock_scanner): CONF_PASSWORD: 'password', } } - mock_ctrl.side_effect = APIError( - '/', 500, 'foo', {}, None) - result = unifi.get_scanner(hass, config) + mock_ctrl.side_effect = unifi.CannotConnect + result = await unifi.async_get_scanner(hass, config) assert result is False -def test_scanner_update(): +async def test_scanner_update(): """Test the scanner update.""" ctrl = mock.MagicMock() fake_clients = [ @@ -161,21 +157,20 @@ def test_scanner_update(): {'mac': '234', 'essid': 'barnet', 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, ] - ctrl.get_clients.return_value = fake_clients - unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None) - assert ctrl.get_clients.call_count == 1 - assert ctrl.get_clients.call_args == mock.call() + ctrl.clients = Clients([], CoroutineMock(return_value=fake_clients)) + scnr = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None) + await scnr.async_update() + assert len(scnr._clients) == 2 def test_scanner_update_error(): """Test the scanner update for error.""" ctrl = mock.MagicMock() - ctrl.get_clients.side_effect = APIError( - '/', 500, 'foo', {}, None) + ctrl.get_clients.side_effect = unifi.aiounifi.AiounifiException unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None) -def test_scan_devices(): +async def test_scan_devices(): """Test the scanning for devices.""" ctrl = mock.MagicMock() fake_clients = [ @@ -184,12 +179,13 @@ def test_scan_devices(): {'mac': '234', 'essid': 'barnet', 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, ] - ctrl.get_clients.return_value = fake_clients - scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None) - assert set(scanner.scan_devices()) == set(['123', '234']) + ctrl.clients = Clients([], CoroutineMock(return_value=fake_clients)) + scnr = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None) + await scnr.async_update() + assert set(await scnr.async_scan_devices()) == set(['123', '234']) -def test_scan_devices_filtered(): +async def test_scan_devices_filtered(): """Test the scanning for devices based on SSID.""" ctrl = mock.MagicMock() fake_clients = [ @@ -204,13 +200,13 @@ def test_scan_devices_filtered(): ] ssid_filter = ['foonet', 'barnet'] - ctrl.get_clients.return_value = fake_clients - scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, ssid_filter, - None) - assert set(scanner.scan_devices()) == set(['123', '234', '890']) + ctrl.clients = Clients([], CoroutineMock(return_value=fake_clients)) + scnr = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, ssid_filter, None) + await scnr.async_update() + assert set(await scnr.async_scan_devices()) == set(['123', '234', '890']) -def test_get_device_name(): +async def test_get_device_name(): """Test the getting of device names.""" ctrl = mock.MagicMock() fake_clients = [ @@ -226,15 +222,16 @@ def test_get_device_name(): 'essid': 'barnet', 'last_seen': '1504786810'}, ] - ctrl.get_clients.return_value = fake_clients - scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None) - assert scanner.get_device_name('123') == 'foobar' - assert scanner.get_device_name('234') == 'Nice Name' - assert scanner.get_device_name('456') is None - assert scanner.get_device_name('unknown') is None + ctrl.clients = Clients([], CoroutineMock(return_value=fake_clients)) + scnr = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, None) + await scnr.async_update() + assert scnr.get_device_name('123') == 'foobar' + assert scnr.get_device_name('234') == 'Nice Name' + assert scnr.get_device_name('456') is None + assert scnr.get_device_name('unknown') is None -def test_monitored_conditions(): +async def test_monitored_conditions(): """Test the filtering of attributes.""" ctrl = mock.MagicMock() fake_clients = [ @@ -254,16 +251,17 @@ def test_monitored_conditions(): 'essid': 'barnet', 'last_seen': dt_util.as_timestamp(dt_util.utcnow())}, ] - ctrl.get_clients.return_value = fake_clients - scanner = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, - ['essid', 'signal', 'latest_assoc_time']) - assert scanner.get_extra_attributes('123') == { + ctrl.clients = Clients([], CoroutineMock(return_value=fake_clients)) + scnr = unifi.UnifiScanner(ctrl, DEFAULT_DETECTION_TIME, None, + ['essid', 'signal', 'latest_assoc_time']) + await scnr.async_update() + assert scnr.get_extra_attributes('123') == { 'essid': 'barnet', 'signal': -60, 'latest_assoc_time': datetime(2000, 1, 1, 0, 0, tzinfo=dt_util.UTC) } - assert scanner.get_extra_attributes('234') == { + assert scnr.get_extra_attributes('234') == { 'essid': 'barnet', 'signal': -42 } - assert scanner.get_extra_attributes('456') == {'essid': 'barnet'} + assert scnr.get_extra_attributes('456') == {'essid': 'barnet'} diff --git a/tests/components/unifi/test_init.py b/tests/components/unifi/test_init.py index d2d19204b40..ec5ab5a577b 100644 --- a/tests/components/unifi/test_init.py +++ b/tests/components/unifi/test_init.py @@ -146,7 +146,8 @@ async def test_flow_works(hass, aioclient_mock): flow.hass = hass with patch('aiounifi.Controller') as mock_controller: - def mock_constructor(host, username, password, port, site, websession): + def mock_constructor( + host, username, password, port, site, websession, sslcontext): """Fake the controller constructor.""" mock_controller.host = host mock_controller.username = username @@ -254,7 +255,8 @@ async def test_user_permissions_low(hass, aioclient_mock): flow.hass = hass with patch('aiounifi.Controller') as mock_controller: - def mock_constructor(host, username, password, port, site, websession): + def mock_constructor( + host, username, password, port, site, websession, sslcontext): """Fake the controller constructor.""" mock_controller.host = host mock_controller.username = username