diff --git a/homeassistant/components/emulated_hue/__init__.py b/homeassistant/components/emulated_hue/__init__.py index c89e4fda358..09ce1a57060 100644 --- a/homeassistant/components/emulated_hue/__init__.py +++ b/homeassistant/components/emulated_hue/__init__.py @@ -4,7 +4,6 @@ Support for local control of entities by emulating the Phillips Hue bridge. For more details about this component, please refer to the documentation at https://home-assistant.io/components/emulated_hue/ """ -import asyncio import logging import voluptuous as vol @@ -111,17 +110,15 @@ def setup(hass, yaml_config): config.upnp_bind_multicast, config.advertise_ip, config.advertise_port) - @asyncio.coroutine - def stop_emulated_hue_bridge(event): + async def stop_emulated_hue_bridge(event): """Stop the emulated hue bridge.""" upnp_listener.stop() - yield from server.stop() + await server.stop() - @asyncio.coroutine - def start_emulated_hue_bridge(event): + async def start_emulated_hue_bridge(event): """Start the emulated hue bridge.""" upnp_listener.start() - yield from server.start() + await server.start() hass.bus.async_listen_once( EVENT_HOMEASSISTANT_STOP, stop_emulated_hue_bridge) diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 1d4306565b1..4d313b5132e 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -4,21 +4,18 @@ This module provides WSGI application to serve the Home Assistant API. For more details about this component, please refer to the documentation at https://home-assistant.io/components/http/ """ -import asyncio + from ipaddress import ip_network -import json import logging import os import ssl from aiohttp import web -from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently +from aiohttp.web_exceptions import HTTPMovedPermanently import voluptuous as vol from homeassistant.const import ( - SERVER_PORT, CONTENT_TYPE_JSON, - EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,) -from homeassistant.core import is_callback + EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, SERVER_PORT) import homeassistant.helpers.config_validation as cv import homeassistant.remote as rem import homeassistant.util as hass_util @@ -28,10 +25,13 @@ from .auth import setup_auth from .ban import setup_bans from .cors import setup_cors from .real_ip import setup_real_ip -from .const import KEY_AUTHENTICATED, KEY_REAL_IP from .static import ( CachingFileResponse, CachingStaticResource, staticresource_middleware) +# Import as alias +from .const import KEY_AUTHENTICATED, KEY_REAL_IP # noqa +from .view import HomeAssistantView # noqa + REQUIREMENTS = ['aiohttp_cors==0.6.0'] DOMAIN = 'http' @@ -98,8 +98,7 @@ CONFIG_SCHEMA = vol.Schema({ }, extra=vol.ALLOW_EXTRA) -@asyncio.coroutine -def async_setup(hass, config): +async def async_setup(hass, config): """Set up the HTTP API and debug interface.""" conf = config.get(DOMAIN) @@ -135,16 +134,14 @@ def async_setup(hass, config): is_ban_enabled=is_ban_enabled ) - @asyncio.coroutine - def stop_server(event): + async def stop_server(event): """Stop the server.""" - yield from server.stop() + await server.stop() - @asyncio.coroutine - def start_server(event): + async def start_server(event): """Start the server.""" hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server) - yield from server.start() + await server.start() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_server) @@ -252,13 +249,11 @@ class HomeAssistantHTTP(object): return if cache_headers: - @asyncio.coroutine - def serve_file(request): + async def serve_file(request): """Serve file from disk.""" return CachingFileResponse(path) else: - @asyncio.coroutine - def serve_file(request): + async def serve_file(request): """Serve file from disk.""" return web.FileResponse(path) @@ -276,14 +271,13 @@ class HomeAssistantHTTP(object): self.app.router.add_route('GET', url_pattern, serve_file) - @asyncio.coroutine - def start(self): + async def start(self): """Start the WSGI server.""" # We misunderstood the startup signal. You're not allowed to change # anything during startup. Temp workaround. # pylint: disable=protected-access self.app._on_startup.freeze() - yield from self.app.startup() + await self.app.startup() if self.ssl_certificate: try: @@ -308,121 +302,18 @@ class HomeAssistantHTTP(object): self._handler = self.app.make_handler(loop=self.hass.loop) try: - self.server = yield from self.hass.loop.create_server( + self.server = await self.hass.loop.create_server( self._handler, self.server_host, self.server_port, ssl=context) except OSError as error: _LOGGER.error("Failed to create HTTP server at port %d: %s", self.server_port, error) - @asyncio.coroutine - def stop(self): + async def stop(self): """Stop the WSGI server.""" if self.server: self.server.close() - yield from self.server.wait_closed() - yield from self.app.shutdown() + await self.server.wait_closed() + await self.app.shutdown() if self._handler: - yield from self._handler.shutdown(10) - yield from self.app.cleanup() - - -class HomeAssistantView(object): - """Base view for all views.""" - - url = None - extra_urls = [] - requires_auth = True # Views inheriting from this class can override this - - # pylint: disable=no-self-use - def json(self, result, status_code=200, headers=None): - """Return a JSON response.""" - msg = json.dumps( - result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8') - response = web.Response( - body=msg, content_type=CONTENT_TYPE_JSON, status=status_code, - headers=headers) - response.enable_compression() - return response - - def json_message(self, message, status_code=200, message_code=None, - headers=None): - """Return a JSON message response.""" - data = {'message': message} - if message_code is not None: - data['code'] = message_code - return self.json(data, status_code, headers=headers) - - @asyncio.coroutine - # pylint: disable=no-self-use - def file(self, request, fil): - """Return a file.""" - assert isinstance(fil, str), 'only string paths allowed' - return web.FileResponse(fil) - - def register(self, router): - """Register the view with a router.""" - assert self.url is not None, 'No url set for view' - urls = [self.url] + self.extra_urls - - for method in ('get', 'post', 'delete', 'put'): - handler = getattr(self, method, None) - - if not handler: - continue - - handler = request_handler_factory(self, handler) - - for url in urls: - router.add_route(method, url, handler) - - # aiohttp_cors does not work with class based views - # self.app.router.add_route('*', self.url, self, name=self.name) - - # for url in self.extra_urls: - # self.app.router.add_route('*', url, self) - - -def request_handler_factory(view, handler): - """Wrap the handler classes.""" - assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \ - "Handler should be a coroutine or a callback." - - @asyncio.coroutine - def handle(request): - """Handle incoming request.""" - if not request.app['hass'].is_running: - return web.Response(status=503) - - authenticated = request.get(KEY_AUTHENTICATED, False) - - if view.requires_auth and not authenticated: - raise HTTPUnauthorized() - - _LOGGER.info('Serving %s to %s (auth: %s)', - request.path, request.get(KEY_REAL_IP), authenticated) - - result = handler(request, **request.match_info) - - if asyncio.iscoroutine(result): - result = yield from result - - if isinstance(result, web.StreamResponse): - # The method handler returned a ready-made Response, how nice of it - return result - - status_code = 200 - - if isinstance(result, tuple): - result, status_code = result - - if isinstance(result, str): - result = result.encode('utf-8') - elif result is None: - result = b'' - elif not isinstance(result, bytes): - assert False, ('Result should be None, string, bytes or Response. ' - 'Got: {}').format(result) - - return web.Response(body=result, status=status_code) - - return handle + await self._handler.shutdown(10) + await self.app.cleanup() diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py index 3128489437a..65c70c37bd2 100644 --- a/homeassistant/components/http/auth.py +++ b/homeassistant/components/http/auth.py @@ -1,5 +1,5 @@ """Authentication for HTTP component.""" -import asyncio + import base64 import hmac import logging @@ -20,13 +20,12 @@ _LOGGER = logging.getLogger(__name__) def setup_auth(app, trusted_networks, api_password): """Create auth middleware for the app.""" @middleware - @asyncio.coroutine - def auth_middleware(request, handler): + async def auth_middleware(request, handler): """Authenticate as middleware.""" # If no password set, just always set authenticated=True if api_password is None: request[KEY_AUTHENTICATED] = True - return (yield from handler(request)) + return await handler(request) # Check authentication authenticated = False @@ -50,10 +49,9 @@ def setup_auth(app, trusted_networks, api_password): authenticated = True request[KEY_AUTHENTICATED] = authenticated - return (yield from handler(request)) + return await handler(request) - @asyncio.coroutine - def auth_startup(app): + async def auth_startup(app): """Initialize auth middleware when app starts up.""" app.middlewares.append(auth_middleware) diff --git a/homeassistant/components/http/ban.py b/homeassistant/components/http/ban.py index 4c797b05b19..fe8b7db84d1 100644 --- a/homeassistant/components/http/ban.py +++ b/homeassistant/components/http/ban.py @@ -1,5 +1,5 @@ """Ban logic for HTTP component.""" -import asyncio + from collections import defaultdict from datetime import datetime from ipaddress import ip_address @@ -38,11 +38,10 @@ SCHEMA_IP_BAN_ENTRY = vol.Schema({ @callback def setup_bans(hass, app, login_threshold): """Create IP Ban middleware for the app.""" - @asyncio.coroutine - def ban_startup(app): + async def ban_startup(app): """Initialize bans when app starts up.""" app.middlewares.append(ban_middleware) - app[KEY_BANNED_IPS] = yield from hass.async_add_job( + app[KEY_BANNED_IPS] = await hass.async_add_job( load_ip_bans_config, hass.config.path(IP_BANS_FILE)) app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int) app[KEY_LOGIN_THRESHOLD] = login_threshold @@ -51,12 +50,11 @@ def setup_bans(hass, app, login_threshold): @middleware -@asyncio.coroutine -def ban_middleware(request, handler): +async def ban_middleware(request, handler): """IP Ban middleware.""" if KEY_BANNED_IPS not in request.app: _LOGGER.error('IP Ban middleware loaded but banned IPs not loaded') - return (yield from handler(request)) + return await handler(request) # Verify if IP is not banned ip_address_ = request[KEY_REAL_IP] @@ -67,14 +65,13 @@ def ban_middleware(request, handler): raise HTTPForbidden() try: - return (yield from handler(request)) + return await handler(request) except HTTPUnauthorized: - yield from process_wrong_login(request) + await process_wrong_login(request) raise -@asyncio.coroutine -def process_wrong_login(request): +async def process_wrong_login(request): """Process a wrong login attempt.""" remote_addr = request[KEY_REAL_IP] @@ -98,7 +95,7 @@ def process_wrong_login(request): request.app[KEY_BANNED_IPS].append(new_ban) hass = request.app['hass'] - yield from hass.async_add_job( + await hass.async_add_job( update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban) _LOGGER.warning( diff --git a/homeassistant/components/http/cors.py b/homeassistant/components/http/cors.py index 2eb92732d1e..0a37f22867e 100644 --- a/homeassistant/components/http/cors.py +++ b/homeassistant/components/http/cors.py @@ -1,5 +1,5 @@ """Provide cors support for the HTTP component.""" -import asyncio + from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE @@ -27,8 +27,7 @@ def setup_cors(app, origins): ) for host in origins }) - @asyncio.coroutine - def cors_startup(app): + async def cors_startup(app): """Initialize cors when app starts up.""" cors_added = set() diff --git a/homeassistant/components/http/data_validator.py b/homeassistant/components/http/data_validator.py index 528c0a598e3..8fc7cd8e658 100644 --- a/homeassistant/components/http/data_validator.py +++ b/homeassistant/components/http/data_validator.py @@ -1,5 +1,5 @@ """Decorator for view methods to help with data validation.""" -import asyncio + from functools import wraps import logging @@ -24,16 +24,15 @@ class RequestDataValidator: def __call__(self, method): """Decorate a function.""" - @asyncio.coroutine @wraps(method) - def wrapper(view, request, *args, **kwargs): + async def wrapper(view, request, *args, **kwargs): """Wrap a request handler with data validation.""" data = None try: - data = yield from request.json() + data = await request.json() except ValueError: if not self._allow_empty or \ - (yield from request.content.read()) != b'': + (await request.content.read()) != b'': _LOGGER.error('Invalid JSON received.') return view.json_message('Invalid JSON.', 400) data = {} @@ -45,7 +44,7 @@ class RequestDataValidator: return view.json_message( 'Message format incorrect: {}'.format(err), 400) - result = yield from method(view, request, *args, **kwargs) + result = await method(view, request, *args, **kwargs) return result return wrapper diff --git a/homeassistant/components/http/real_ip.py b/homeassistant/components/http/real_ip.py index 1e50f33f69e..c394016a683 100644 --- a/homeassistant/components/http/real_ip.py +++ b/homeassistant/components/http/real_ip.py @@ -1,5 +1,5 @@ """Middleware to fetch real IP.""" -import asyncio + from ipaddress import ip_address from aiohttp.web import middleware @@ -14,8 +14,7 @@ from .const import KEY_REAL_IP def setup_real_ip(app, use_x_forwarded_for): """Create IP Ban middleware for the app.""" @middleware - @asyncio.coroutine - def real_ip_middleware(request, handler): + async def real_ip_middleware(request, handler): """Real IP middleware.""" if (use_x_forwarded_for and X_FORWARDED_FOR in request.headers): @@ -25,10 +24,9 @@ def setup_real_ip(app, use_x_forwarded_for): request[KEY_REAL_IP] = \ ip_address(request.transport.get_extra_info('peername')[0]) - return (yield from handler(request)) + return await handler(request) - @asyncio.coroutine - def app_startup(app): + async def app_startup(app): """Initialize bans when app starts up.""" app.middlewares.append(real_ip_middleware) diff --git a/homeassistant/components/http/static.py b/homeassistant/components/http/static.py index f444e4b3180..3fbaf703d06 100644 --- a/homeassistant/components/http/static.py +++ b/homeassistant/components/http/static.py @@ -1,5 +1,5 @@ """Static file handling for HTTP component.""" -import asyncio + import re from aiohttp import hdrs @@ -14,8 +14,7 @@ _FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE) class CachingStaticResource(StaticResource): """Static Resource handler that will add cache headers.""" - @asyncio.coroutine - def _handle(self, request): + async def _handle(self, request): filename = URL(request.match_info['filename']).path try: # PyLint is wrong about resolve not being a member. @@ -32,7 +31,7 @@ class CachingStaticResource(StaticResource): raise HTTPNotFound() from error if filepath.is_dir(): - return (yield from super()._handle(request)) + return await super()._handle(request) elif filepath.is_file(): return CachingFileResponse(filepath, chunk_size=self._chunk_size) else: @@ -49,26 +48,24 @@ class CachingFileResponse(FileResponse): orig_sendfile = self._sendfile - @asyncio.coroutine - def sendfile(request, fobj, count): + async def sendfile(request, fobj, count): """Sendfile that includes a cache header.""" cache_time = 31 * 86400 # = 1 month self.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format( cache_time) - yield from orig_sendfile(request, fobj, count) + await orig_sendfile(request, fobj, count) # Overwriting like this because __init__ can change implementation. self._sendfile = sendfile @middleware -@asyncio.coroutine -def staticresource_middleware(request, handler): +async def staticresource_middleware(request, handler): """Middleware to strip out fingerprint from fingerprinted assets.""" path = request.path if not path.startswith('/static/') and not path.startswith('/frontend'): - return (yield from handler(request)) + return await handler(request) fingerprinted = _FINGERPRINT.match(request.match_info['filename']) @@ -76,4 +73,4 @@ def staticresource_middleware(request, handler): request.match_info['filename'] = \ '{}.{}'.format(*fingerprinted.groups()) - return (yield from handler(request)) + return await handler(request) diff --git a/homeassistant/components/http/view.py b/homeassistant/components/http/view.py new file mode 100644 index 00000000000..299a10e9f5a --- /dev/null +++ b/homeassistant/components/http/view.py @@ -0,0 +1,121 @@ +""" +This module provides WSGI application to serve the Home Assistant API. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/components/http/ +""" +import asyncio +import json +import logging + +from aiohttp import web +from aiohttp.web_exceptions import HTTPUnauthorized + +import homeassistant.remote as rem +from homeassistant.core import is_callback +from homeassistant.const import CONTENT_TYPE_JSON + +from .const import KEY_AUTHENTICATED, KEY_REAL_IP + + +_LOGGER = logging.getLogger(__name__) + + +class HomeAssistantView(object): + """Base view for all views.""" + + url = None + extra_urls = [] + requires_auth = True # Views inheriting from this class can override this + + # pylint: disable=no-self-use + def json(self, result, status_code=200, headers=None): + """Return a JSON response.""" + msg = json.dumps( + result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8') + response = web.Response( + body=msg, content_type=CONTENT_TYPE_JSON, status=status_code, + headers=headers) + response.enable_compression() + return response + + def json_message(self, message, status_code=200, message_code=None, + headers=None): + """Return a JSON message response.""" + data = {'message': message} + if message_code is not None: + data['code'] = message_code + return self.json(data, status_code, headers=headers) + + # pylint: disable=no-self-use + async def file(self, request, fil): + """Return a file.""" + assert isinstance(fil, str), 'only string paths allowed' + return web.FileResponse(fil) + + def register(self, router): + """Register the view with a router.""" + assert self.url is not None, 'No url set for view' + urls = [self.url] + self.extra_urls + + for method in ('get', 'post', 'delete', 'put'): + handler = getattr(self, method, None) + + if not handler: + continue + + handler = request_handler_factory(self, handler) + + for url in urls: + router.add_route(method, url, handler) + + # aiohttp_cors does not work with class based views + # self.app.router.add_route('*', self.url, self, name=self.name) + + # for url in self.extra_urls: + # self.app.router.add_route('*', url, self) + + +def request_handler_factory(view, handler): + """Wrap the handler classes.""" + assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \ + "Handler should be a coroutine or a callback." + + async def handle(request): + """Handle incoming request.""" + if not request.app['hass'].is_running: + return web.Response(status=503) + + authenticated = request.get(KEY_AUTHENTICATED, False) + + if view.requires_auth and not authenticated: + raise HTTPUnauthorized() + + _LOGGER.info('Serving %s to %s (auth: %s)', + request.path, request.get(KEY_REAL_IP), authenticated) + + result = handler(request, **request.match_info) + + if asyncio.iscoroutine(result): + result = await result + + if isinstance(result, web.StreamResponse): + # The method handler returned a ready-made Response, how nice of it + return result + + status_code = 200 + + if isinstance(result, tuple): + result, status_code = result + + if isinstance(result, str): + result = result.encode('utf-8') + elif result is None: + result = b'' + elif not isinstance(result, bytes): + assert False, ('Result should be None, string, bytes or Response. ' + 'Got: {}').format(result) + + return web.Response(body=result, status=status_code) + + return handle diff --git a/homeassistant/components/websocket_api.py b/homeassistant/components/websocket_api.py index b79812a8dce..47ef2c3eace 100644 --- a/homeassistant/components/websocket_api.py +++ b/homeassistant/components/websocket_api.py @@ -191,8 +191,7 @@ def result_message(iden, result=None): } -@asyncio.coroutine -def async_setup(hass, config): +async def async_setup(hass, config): """Initialize the websocket API.""" hass.http.register_view(WebsocketAPIView) return True @@ -205,11 +204,10 @@ class WebsocketAPIView(HomeAssistantView): url = URL requires_auth = False - @asyncio.coroutine - def get(self, request): + async def get(self, request): """Handle an incoming websocket connection.""" # pylint: disable=no-self-use - return ActiveConnection(request.app['hass'], request).handle() + return await ActiveConnection(request.app['hass'], request).handle() class ActiveConnection: @@ -233,17 +231,16 @@ class ActiveConnection: """Print an error message.""" _LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2) - @asyncio.coroutine - def _writer(self): + async def _writer(self): """Write outgoing messages.""" # Exceptions if Socket disconnected or cancelled by connection handler with suppress(RuntimeError, *CANCELLATION_ERRORS): while not self.wsock.closed: - message = yield from self.to_write.get() + message = await self.to_write.get() if message is None: break self.debug("Sending", message) - yield from self.wsock.send_json(message, dumps=JSON_DUMP) + await self.wsock.send_json(message, dumps=JSON_DUMP) @callback def send_message_outside(self, message): @@ -266,12 +263,11 @@ class ActiveConnection: self._handle_task.cancel() self._writer_task.cancel() - @asyncio.coroutine - def handle(self): + async def handle(self): """Handle the websocket connection.""" request = self.request wsock = self.wsock = web.WebSocketResponse(heartbeat=55) - yield from wsock.prepare(request) + await wsock.prepare(request) self.debug("Connected") # Get a reference to current task so we can cancel our connection @@ -294,8 +290,8 @@ class ActiveConnection: authenticated = True else: - yield from self.wsock.send_json(auth_required_message()) - msg = yield from wsock.receive_json() + await self.wsock.send_json(auth_required_message()) + msg = await wsock.receive_json() msg = AUTH_MESSAGE_SCHEMA(msg) if validate_password(request, msg['api_password']): @@ -303,18 +299,18 @@ class ActiveConnection: else: self.debug("Invalid password") - yield from self.wsock.send_json( + await self.wsock.send_json( auth_invalid_message('Invalid password')) if not authenticated: - yield from process_wrong_login(request) + await process_wrong_login(request) return wsock - yield from self.wsock.send_json(auth_ok_message()) + await self.wsock.send_json(auth_ok_message()) # ---------- AUTH PHASE OVER ---------- - msg = yield from wsock.receive_json() + msg = await wsock.receive_json() last_id = 0 while msg: @@ -332,7 +328,7 @@ class ActiveConnection: getattr(self, handler_name)(msg) last_id = cur_id - msg = yield from wsock.receive_json() + msg = await wsock.receive_json() except vol.Invalid as err: error_msg = "Message incorrectly formatted: " @@ -394,11 +390,11 @@ class ActiveConnection: self.to_write.put_nowait(final_message) self.to_write.put_nowait(None) # Make sure all error messages are written before closing - yield from self._writer_task + await self._writer_task except asyncio.QueueFull: self._writer_task.cancel() - yield from wsock.close() + await wsock.close() self.debug("Closed connection") return wsock @@ -410,8 +406,7 @@ class ActiveConnection: """ msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg) - @asyncio.coroutine - def forward_events(event): + async def forward_events(event): """Forward events to websocket.""" if event.event_type == EVENT_TIME_CHANGED: return @@ -447,10 +442,9 @@ class ActiveConnection: """ msg = CALL_SERVICE_MESSAGE_SCHEMA(msg) - @asyncio.coroutine - def call_service_helper(msg): + async def call_service_helper(msg): """Call a service and fire complete message.""" - yield from self.hass.services.async_call( + await self.hass.services.async_call( msg['domain'], msg['service'], msg.get('service_data'), True) self.send_message_outside(result_message(msg['id'])) @@ -473,10 +467,9 @@ class ActiveConnection: """ msg = GET_SERVICES_MESSAGE_SCHEMA(msg) - @asyncio.coroutine - def get_services_helper(msg): + async def get_services_helper(msg): """Get available services and fire complete message.""" - descriptions = yield from async_get_all_descriptions(self.hass) + descriptions = await async_get_all_descriptions(self.hass) self.send_message_outside(result_message(msg['id'], descriptions)) self.hass.async_add_job(get_services_helper(msg)) diff --git a/tests/components/http/__init__.py b/tests/components/http/__init__.py index ef9817a2f1b..64f6c94c0da 100644 --- a/tests/components/http/__init__.py +++ b/tests/components/http/__init__.py @@ -1,5 +1,4 @@ """Tests for the HTTP component.""" -import asyncio from ipaddress import ip_address from aiohttp import web @@ -18,18 +17,16 @@ def mock_real_ip(app): nonlocal ip_to_mock ip_to_mock = value - @asyncio.coroutine @web.middleware - def mock_real_ip(request, handler): + async def mock_real_ip(request, handler): """Mock Real IP middleware.""" nonlocal ip_to_mock request[KEY_REAL_IP] = ip_address(ip_to_mock) - return (yield from handler(request)) + return (await handler(request)) - @asyncio.coroutine - def real_ip_startup(app): + async def real_ip_startup(app): """Startup of real ip.""" app.middlewares.insert(0, mock_real_ip) diff --git a/tests/components/http/test_auth.py b/tests/components/http/test_auth.py index c2687c05a8f..604ee9c0c9b 100644 --- a/tests/components/http/test_auth.py +++ b/tests/components/http/test_auth.py @@ -1,6 +1,5 @@ """The tests for the Home Assistant HTTP component.""" # pylint: disable=protected-access -import asyncio from ipaddress import ip_network from unittest.mock import patch @@ -30,8 +29,7 @@ TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1', UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1'] -@asyncio.coroutine -def mock_handler(request): +async def mock_handler(request): """Return if request was authenticated.""" if not request[KEY_AUTHENTICATED]: raise HTTPUnauthorized @@ -47,84 +45,79 @@ def app(): return app -@asyncio.coroutine -def test_auth_middleware_loaded_by_default(hass): +async def test_auth_middleware_loaded_by_default(hass): """Test accessing to server from banned IP when feature is off.""" with patch('homeassistant.components.http.setup_auth') as mock_setup: - yield from async_setup_component(hass, 'http', { + await async_setup_component(hass, 'http', { 'http': {} }) assert len(mock_setup.mock_calls) == 1 -@asyncio.coroutine -def test_access_without_password(app, test_client): +async def test_access_without_password(app, test_client): """Test access without password.""" setup_auth(app, [], None) - client = yield from test_client(app) + client = await test_client(app) - resp = yield from client.get('/') + resp = await client.get('/') assert resp.status == 200 -@asyncio.coroutine -def test_access_with_password_in_header(app, test_client): +async def test_access_with_password_in_header(app, test_client): """Test access with password in URL.""" setup_auth(app, [], API_PASSWORD) - client = yield from test_client(app) + client = await test_client(app) - req = yield from client.get( + req = await client.get( '/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD}) assert req.status == 200 - req = yield from client.get( + req = await client.get( '/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'}) assert req.status == 401 -@asyncio.coroutine -def test_access_with_password_in_query(app, test_client): +async def test_access_with_password_in_query(app, test_client): """Test access without password.""" setup_auth(app, [], API_PASSWORD) - client = yield from test_client(app) + client = await test_client(app) - resp = yield from client.get('/', params={ + resp = await client.get('/', params={ 'api_password': API_PASSWORD }) assert resp.status == 200 - resp = yield from client.get('/') + resp = await client.get('/') assert resp.status == 401 - resp = yield from client.get('/', params={ + resp = await client.get('/', params={ 'api_password': 'wrong-password' }) assert resp.status == 401 -@asyncio.coroutine -def test_basic_auth_works(app, test_client): +async def test_basic_auth_works(app, test_client): """Test access with basic authentication.""" setup_auth(app, [], API_PASSWORD) - client = yield from test_client(app) + client = await test_client(app) - req = yield from client.get( + req = await client.get( '/', auth=BasicAuth('homeassistant', API_PASSWORD)) assert req.status == 200 - req = yield from client.get( + req = await client.get( '/', auth=BasicAuth('wrong_username', API_PASSWORD)) assert req.status == 401 - req = yield from client.get( + req = await client.get( '/', auth=BasicAuth('homeassistant', 'wrong password')) assert req.status == 401 - req = yield from client.get( + req = await client.get( '/', headers={ 'authorization': 'NotBasic abcdefg' @@ -132,8 +125,7 @@ def test_basic_auth_works(app, test_client): assert req.status == 401 -@asyncio.coroutine -def test_access_with_trusted_ip(test_client): +async def test_access_with_trusted_ip(test_client): """Test access with an untrusted ip address.""" app = web.Application() app.router.add_get('/', mock_handler) @@ -141,16 +133,16 @@ def test_access_with_trusted_ip(test_client): setup_auth(app, TRUSTED_NETWORKS, 'some-pass') set_mock_ip = mock_real_ip(app) - client = yield from test_client(app) + client = await test_client(app) for remote_addr in UNTRUSTED_ADDRESSES: set_mock_ip(remote_addr) - resp = yield from client.get('/') + resp = await client.get('/') assert resp.status == 401, \ "{} shouldn't be trusted".format(remote_addr) for remote_addr in TRUSTED_ADDRESSES: set_mock_ip(remote_addr) - resp = yield from client.get('/') + resp = await client.get('/') assert resp.status == 200, \ "{} should be trusted".format(remote_addr) diff --git a/tests/components/http/test_ban.py b/tests/components/http/test_ban.py index bd6df4f4e73..2d7885d959f 100644 --- a/tests/components/http/test_ban.py +++ b/tests/components/http/test_ban.py @@ -1,6 +1,5 @@ """The tests for the Home Assistant HTTP component.""" # pylint: disable=protected-access -import asyncio from unittest.mock import patch, mock_open from aiohttp import web @@ -16,8 +15,7 @@ from . import mock_real_ip BANNED_IPS = ['200.201.202.203', '100.64.0.2'] -@asyncio.coroutine -def test_access_from_banned_ip(hass, test_client): +async def test_access_from_banned_ip(hass, test_client): """Test accessing to server from banned IP. Both trusted and not.""" app = web.Application() setup_bans(hass, app, 5) @@ -26,19 +24,18 @@ def test_access_from_banned_ip(hass, test_client): with patch('homeassistant.components.http.ban.load_ip_bans_config', return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS]): - client = yield from test_client(app) + client = await test_client(app) for remote_addr in BANNED_IPS: set_real_ip(remote_addr) - resp = yield from client.get('/') + resp = await client.get('/') assert resp.status == 403 -@asyncio.coroutine -def test_ban_middleware_not_loaded_by_config(hass): +async def test_ban_middleware_not_loaded_by_config(hass): """Test accessing to server from banned IP when feature is off.""" with patch('homeassistant.components.http.setup_bans') as mock_setup: - yield from async_setup_component(hass, 'http', { + await async_setup_component(hass, 'http', { 'http': { http.CONF_IP_BAN_ENABLED: False, } @@ -47,25 +44,22 @@ def test_ban_middleware_not_loaded_by_config(hass): assert len(mock_setup.mock_calls) == 0 -@asyncio.coroutine -def test_ban_middleware_loaded_by_default(hass): +async def test_ban_middleware_loaded_by_default(hass): """Test accessing to server from banned IP when feature is off.""" with patch('homeassistant.components.http.setup_bans') as mock_setup: - yield from async_setup_component(hass, 'http', { + await async_setup_component(hass, 'http', { 'http': {} }) assert len(mock_setup.mock_calls) == 1 -@asyncio.coroutine -def test_ip_bans_file_creation(hass, test_client): +async def test_ip_bans_file_creation(hass, test_client): """Testing if banned IP file created.""" app = web.Application() app['hass'] = hass - @asyncio.coroutine - def unauth_handler(request): + async def unauth_handler(request): """Return a mock web response.""" raise HTTPUnauthorized @@ -76,21 +70,21 @@ def test_ip_bans_file_creation(hass, test_client): with patch('homeassistant.components.http.ban.load_ip_bans_config', return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS]): - client = yield from test_client(app) + client = await test_client(app) m = mock_open() with patch('homeassistant.components.http.ban.open', m, create=True): - resp = yield from client.get('/') + resp = await client.get('/') assert resp.status == 401 assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) assert m.call_count == 0 - resp = yield from client.get('/') + resp = await client.get('/') assert resp.status == 401 assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1 m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a') - resp = yield from client.get('/') + resp = await client.get('/') assert resp.status == 403 assert m.call_count == 1 diff --git a/tests/components/http/test_cors.py b/tests/components/http/test_cors.py index 22b70e1c0c5..50464b36277 100644 --- a/tests/components/http/test_cors.py +++ b/tests/components/http/test_cors.py @@ -1,5 +1,4 @@ """Test cors for the HTTP component.""" -import asyncio from unittest.mock import patch from aiohttp import web @@ -20,22 +19,20 @@ from homeassistant.components.http.cors import setup_cors TRUSTED_ORIGIN = 'https://home-assistant.io' -@asyncio.coroutine -def test_cors_middleware_not_loaded_by_default(hass): +async def test_cors_middleware_not_loaded_by_default(hass): """Test accessing to server from banned IP when feature is off.""" with patch('homeassistant.components.http.setup_cors') as mock_setup: - yield from async_setup_component(hass, 'http', { + await async_setup_component(hass, 'http', { 'http': {} }) assert len(mock_setup.mock_calls) == 0 -@asyncio.coroutine -def test_cors_middleware_loaded_from_config(hass): +async def test_cors_middleware_loaded_from_config(hass): """Test accessing to server from banned IP when feature is off.""" with patch('homeassistant.components.http.setup_cors') as mock_setup: - yield from async_setup_component(hass, 'http', { + await async_setup_component(hass, 'http', { 'http': { 'cors_allowed_origins': ['http://home-assistant.io'] } @@ -44,8 +41,7 @@ def test_cors_middleware_loaded_from_config(hass): assert len(mock_setup.mock_calls) == 1 -@asyncio.coroutine -def mock_handler(request): +async def mock_handler(request): """Return if request was authenticated.""" return web.Response(status=200) @@ -59,10 +55,9 @@ def client(loop, test_client): return loop.run_until_complete(test_client(app)) -@asyncio.coroutine -def test_cors_requests(client): +async def test_cors_requests(client): """Test cross origin requests.""" - req = yield from client.get('/', headers={ + req = await client.get('/', headers={ ORIGIN: TRUSTED_ORIGIN }) assert req.status == 200 @@ -70,7 +65,7 @@ def test_cors_requests(client): TRUSTED_ORIGIN # With password in URL - req = yield from client.get('/', params={ + req = await client.get('/', params={ 'api_password': 'some-pass' }, headers={ ORIGIN: TRUSTED_ORIGIN @@ -80,7 +75,7 @@ def test_cors_requests(client): TRUSTED_ORIGIN # With password in headers - req = yield from client.get('/', headers={ + req = await client.get('/', headers={ HTTP_HEADER_HA_AUTH: 'some-pass', ORIGIN: TRUSTED_ORIGIN }) @@ -89,10 +84,9 @@ def test_cors_requests(client): TRUSTED_ORIGIN -@asyncio.coroutine -def test_cors_preflight_allowed(client): +async def test_cors_preflight_allowed(client): """Test cross origin resource sharing preflight (OPTIONS) request.""" - req = yield from client.options('/', headers={ + req = await client.options('/', headers={ ORIGIN: TRUSTED_ORIGIN, ACCESS_CONTROL_REQUEST_METHOD: 'GET', ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access' diff --git a/tests/components/http/test_data_validator.py b/tests/components/http/test_data_validator.py index f00be4fc6f9..6cca1af8ccc 100644 --- a/tests/components/http/test_data_validator.py +++ b/tests/components/http/test_data_validator.py @@ -1,5 +1,4 @@ """Test data validator decorator.""" -import asyncio from unittest.mock import Mock from aiohttp import web @@ -9,8 +8,7 @@ from homeassistant.components.http import HomeAssistantView from homeassistant.components.http.data_validator import RequestDataValidator -@asyncio.coroutine -def get_client(test_client, validator): +async def get_client(test_client, validator): """Generate a client that hits a view decorated with validator.""" app = web.Application() app['hass'] = Mock(is_running=True) @@ -20,58 +18,55 @@ def get_client(test_client, validator): name = 'test' requires_auth = False - @asyncio.coroutine @validator - def post(self, request, data): + async def post(self, request, data): """Test method.""" return b'' TestView().register(app.router) - client = yield from test_client(app) + client = await test_client(app) return client -@asyncio.coroutine -def test_validator(test_client): +async def test_validator(test_client): """Test the validator.""" - client = yield from get_client( + client = await get_client( test_client, RequestDataValidator(vol.Schema({ vol.Required('test'): str }))) - resp = yield from client.post('/', json={ + resp = await client.post('/', json={ 'test': 'bla' }) assert resp.status == 200 - resp = yield from client.post('/', json={ + resp = await client.post('/', json={ 'test': 100 }) assert resp.status == 400 - resp = yield from client.post('/') + resp = await client.post('/') assert resp.status == 400 -@asyncio.coroutine -def test_validator_allow_empty(test_client): +async def test_validator_allow_empty(test_client): """Test the validator with empty data.""" - client = yield from get_client( + client = await get_client( test_client, RequestDataValidator(vol.Schema({ # Although we allow empty, our schema should still be able # to validate an empty dict. vol.Optional('test'): str }), allow_empty=True)) - resp = yield from client.post('/', json={ + resp = await client.post('/', json={ 'test': 'bla' }) assert resp.status == 200 - resp = yield from client.post('/', json={ + resp = await client.post('/', json={ 'test': 100 }) assert resp.status == 400 - resp = yield from client.post('/') + resp = await client.post('/') assert resp.status == 200 diff --git a/tests/components/http/test_init.py b/tests/components/http/test_init.py index ab06b48043e..1dcf45f48c3 100644 --- a/tests/components/http/test_init.py +++ b/tests/components/http/test_init.py @@ -1,6 +1,4 @@ """The tests for the Home Assistant HTTP component.""" -import asyncio - from homeassistant.setup import async_setup_component import homeassistant.components.http as http @@ -12,16 +10,14 @@ class TestView(http.HomeAssistantView): name = 'test' url = '/hello' - @asyncio.coroutine - def get(self, request): + async def get(self, request): """Return a get request.""" return 'hello' -@asyncio.coroutine -def test_registering_view_while_running(hass, test_client, unused_port): +async def test_registering_view_while_running(hass, test_client, unused_port): """Test that we can register a view while the server is running.""" - yield from async_setup_component( + await async_setup_component( hass, http.DOMAIN, { http.DOMAIN: { http.CONF_SERVER_PORT: unused_port(), @@ -29,15 +25,14 @@ def test_registering_view_while_running(hass, test_client, unused_port): } ) - yield from hass.async_start() + await hass.async_start() # This raises a RuntimeError if app is frozen hass.http.register_view(TestView) -@asyncio.coroutine -def test_api_base_url_with_domain(hass): +async def test_api_base_url_with_domain(hass): """Test setting API URL.""" - result = yield from async_setup_component(hass, 'http', { + result = await async_setup_component(hass, 'http', { 'http': { 'base_url': 'example.com' } @@ -46,10 +41,9 @@ def test_api_base_url_with_domain(hass): assert hass.config.api.base_url == 'http://example.com' -@asyncio.coroutine -def test_api_base_url_with_ip(hass): +async def test_api_base_url_with_ip(hass): """Test setting api url.""" - result = yield from async_setup_component(hass, 'http', { + result = await async_setup_component(hass, 'http', { 'http': { 'server_host': '1.1.1.1' } @@ -58,10 +52,9 @@ def test_api_base_url_with_ip(hass): assert hass.config.api.base_url == 'http://1.1.1.1:8123' -@asyncio.coroutine -def test_api_base_url_with_ip_port(hass): +async def test_api_base_url_with_ip_port(hass): """Test setting api url.""" - result = yield from async_setup_component(hass, 'http', { + result = await async_setup_component(hass, 'http', { 'http': { 'base_url': '1.1.1.1:8124' } @@ -70,10 +63,9 @@ def test_api_base_url_with_ip_port(hass): assert hass.config.api.base_url == 'http://1.1.1.1:8124' -@asyncio.coroutine -def test_api_no_base_url(hass): +async def test_api_no_base_url(hass): """Test setting api url.""" - result = yield from async_setup_component(hass, 'http', { + result = await async_setup_component(hass, 'http', { 'http': { } }) @@ -81,10 +73,9 @@ def test_api_no_base_url(hass): assert hass.config.api.base_url == 'http://127.0.0.1:8123' -@asyncio.coroutine -def test_not_log_password(hass, unused_port, test_client, caplog): +async def test_not_log_password(hass, unused_port, test_client, caplog): """Test access with password doesn't get logged.""" - result = yield from async_setup_component(hass, 'api', { + result = await async_setup_component(hass, 'api', { 'http': { http.CONF_SERVER_PORT: unused_port(), http.CONF_API_PASSWORD: 'some-pass' @@ -92,9 +83,9 @@ def test_not_log_password(hass, unused_port, test_client, caplog): }) assert result - client = yield from test_client(hass.http.app) + client = await test_client(hass.http.app) - resp = yield from client.get('/api/', params={ + resp = await client.get('/api/', params={ 'api_password': 'some-pass' }) diff --git a/tests/components/http/test_real_ip.py b/tests/components/http/test_real_ip.py index 90201ab4c10..3e4f9023537 100644 --- a/tests/components/http/test_real_ip.py +++ b/tests/components/http/test_real_ip.py @@ -1,6 +1,4 @@ """Test real IP middleware.""" -import asyncio - from aiohttp import web from aiohttp.hdrs import X_FORWARDED_FOR @@ -8,41 +6,38 @@ from homeassistant.components.http.real_ip import setup_real_ip from homeassistant.components.http.const import KEY_REAL_IP -@asyncio.coroutine -def mock_handler(request): +async def mock_handler(request): """Handler that returns the real IP as text.""" return web.Response(text=str(request[KEY_REAL_IP])) -@asyncio.coroutine -def test_ignore_x_forwarded_for(test_client): +async def test_ignore_x_forwarded_for(test_client): """Test that we get the IP from the transport.""" app = web.Application() app.router.add_get('/', mock_handler) setup_real_ip(app, False) - mock_api_client = yield from test_client(app) + mock_api_client = await test_client(app) - resp = yield from mock_api_client.get('/', headers={ + resp = await mock_api_client.get('/', headers={ X_FORWARDED_FOR: '255.255.255.255' }) assert resp.status == 200 - text = yield from resp.text() + text = await resp.text() assert text != '255.255.255.255' -@asyncio.coroutine -def test_use_x_forwarded_for(test_client): +async def test_use_x_forwarded_for(test_client): """Test that we get the IP from the transport.""" app = web.Application() app.router.add_get('/', mock_handler) setup_real_ip(app, True) - mock_api_client = yield from test_client(app) + mock_api_client = await test_client(app) - resp = yield from mock_api_client.get('/', headers={ + resp = await mock_api_client.get('/', headers={ X_FORWARDED_FOR: '255.255.255.255' }) assert resp.status == 200 - text = yield from resp.text() + text = await resp.text() assert text == '255.255.255.255'