Move HomeAssistantView to separate file. Convert http to async syntax. [skip ci] (#12982)

* Move HomeAssistantView to separate file. Convert http to async syntax.

* pylint

* websocket api

* update emulated_hue for async/await

* Lint
pull/12996/head
Boyi C 2018-03-09 09:51:49 +08:00 committed by Paulus Schoutsen
parent 2ee73ca911
commit 321eb2ec6f
17 changed files with 292 additions and 344 deletions

View File

@ -4,7 +4,6 @@ Support for local control of entities by emulating the Phillips Hue bridge.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/emulated_hue/
"""
import asyncio
import logging
import voluptuous as vol
@ -111,17 +110,15 @@ def setup(hass, yaml_config):
config.upnp_bind_multicast, config.advertise_ip,
config.advertise_port)
@asyncio.coroutine
def stop_emulated_hue_bridge(event):
async def stop_emulated_hue_bridge(event):
"""Stop the emulated hue bridge."""
upnp_listener.stop()
yield from server.stop()
await server.stop()
@asyncio.coroutine
def start_emulated_hue_bridge(event):
async def start_emulated_hue_bridge(event):
"""Start the emulated hue bridge."""
upnp_listener.start()
yield from server.start()
await server.start()
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, stop_emulated_hue_bridge)

View File

@ -4,21 +4,18 @@ This module provides WSGI application to serve the Home Assistant API.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/http/
"""
import asyncio
from ipaddress import ip_network
import json
import logging
import os
import ssl
from aiohttp import web
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently
from aiohttp.web_exceptions import HTTPMovedPermanently
import voluptuous as vol
from homeassistant.const import (
SERVER_PORT, CONTENT_TYPE_JSON,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,)
from homeassistant.core import is_callback
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, SERVER_PORT)
import homeassistant.helpers.config_validation as cv
import homeassistant.remote as rem
import homeassistant.util as hass_util
@ -28,10 +25,13 @@ from .auth import setup_auth
from .ban import setup_bans
from .cors import setup_cors
from .real_ip import setup_real_ip
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
from .static import (
CachingFileResponse, CachingStaticResource, staticresource_middleware)
# Import as alias
from .const import KEY_AUTHENTICATED, KEY_REAL_IP # noqa
from .view import HomeAssistantView # noqa
REQUIREMENTS = ['aiohttp_cors==0.6.0']
DOMAIN = 'http'
@ -98,8 +98,7 @@ CONFIG_SCHEMA = vol.Schema({
}, extra=vol.ALLOW_EXTRA)
@asyncio.coroutine
def async_setup(hass, config):
async def async_setup(hass, config):
"""Set up the HTTP API and debug interface."""
conf = config.get(DOMAIN)
@ -135,16 +134,14 @@ def async_setup(hass, config):
is_ban_enabled=is_ban_enabled
)
@asyncio.coroutine
def stop_server(event):
async def stop_server(event):
"""Stop the server."""
yield from server.stop()
await server.stop()
@asyncio.coroutine
def start_server(event):
async def start_server(event):
"""Start the server."""
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
yield from server.start()
await server.start()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_server)
@ -252,13 +249,11 @@ class HomeAssistantHTTP(object):
return
if cache_headers:
@asyncio.coroutine
def serve_file(request):
async def serve_file(request):
"""Serve file from disk."""
return CachingFileResponse(path)
else:
@asyncio.coroutine
def serve_file(request):
async def serve_file(request):
"""Serve file from disk."""
return web.FileResponse(path)
@ -276,14 +271,13 @@ class HomeAssistantHTTP(object):
self.app.router.add_route('GET', url_pattern, serve_file)
@asyncio.coroutine
def start(self):
async def start(self):
"""Start the WSGI server."""
# We misunderstood the startup signal. You're not allowed to change
# anything during startup. Temp workaround.
# pylint: disable=protected-access
self.app._on_startup.freeze()
yield from self.app.startup()
await self.app.startup()
if self.ssl_certificate:
try:
@ -308,121 +302,18 @@ class HomeAssistantHTTP(object):
self._handler = self.app.make_handler(loop=self.hass.loop)
try:
self.server = yield from self.hass.loop.create_server(
self.server = await self.hass.loop.create_server(
self._handler, self.server_host, self.server_port, ssl=context)
except OSError as error:
_LOGGER.error("Failed to create HTTP server at port %d: %s",
self.server_port, error)
@asyncio.coroutine
def stop(self):
async def stop(self):
"""Stop the WSGI server."""
if self.server:
self.server.close()
yield from self.server.wait_closed()
yield from self.app.shutdown()
await self.server.wait_closed()
await self.app.shutdown()
if self._handler:
yield from self._handler.shutdown(10)
yield from self.app.cleanup()
class HomeAssistantView(object):
"""Base view for all views."""
url = None
extra_urls = []
requires_auth = True # Views inheriting from this class can override this
# pylint: disable=no-self-use
def json(self, result, status_code=200, headers=None):
"""Return a JSON response."""
msg = json.dumps(
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
response = web.Response(
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code,
headers=headers)
response.enable_compression()
return response
def json_message(self, message, status_code=200, message_code=None,
headers=None):
"""Return a JSON message response."""
data = {'message': message}
if message_code is not None:
data['code'] = message_code
return self.json(data, status_code, headers=headers)
@asyncio.coroutine
# pylint: disable=no-self-use
def file(self, request, fil):
"""Return a file."""
assert isinstance(fil, str), 'only string paths allowed'
return web.FileResponse(fil)
def register(self, router):
"""Register the view with a router."""
assert self.url is not None, 'No url set for view'
urls = [self.url] + self.extra_urls
for method in ('get', 'post', 'delete', 'put'):
handler = getattr(self, method, None)
if not handler:
continue
handler = request_handler_factory(self, handler)
for url in urls:
router.add_route(method, url, handler)
# aiohttp_cors does not work with class based views
# self.app.router.add_route('*', self.url, self, name=self.name)
# for url in self.extra_urls:
# self.app.router.add_route('*', url, self)
def request_handler_factory(view, handler):
"""Wrap the handler classes."""
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
"Handler should be a coroutine or a callback."
@asyncio.coroutine
def handle(request):
"""Handle incoming request."""
if not request.app['hass'].is_running:
return web.Response(status=503)
authenticated = request.get(KEY_AUTHENTICATED, False)
if view.requires_auth and not authenticated:
raise HTTPUnauthorized()
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, request.get(KEY_REAL_IP), authenticated)
result = handler(request, **request.match_info)
if asyncio.iscoroutine(result):
result = yield from result
if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it
return result
status_code = 200
if isinstance(result, tuple):
result, status_code = result
if isinstance(result, str):
result = result.encode('utf-8')
elif result is None:
result = b''
elif not isinstance(result, bytes):
assert False, ('Result should be None, string, bytes or Response. '
'Got: {}').format(result)
return web.Response(body=result, status=status_code)
return handle
await self._handler.shutdown(10)
await self.app.cleanup()

View File

@ -1,5 +1,5 @@
"""Authentication for HTTP component."""
import asyncio
import base64
import hmac
import logging
@ -20,13 +20,12 @@ _LOGGER = logging.getLogger(__name__)
def setup_auth(app, trusted_networks, api_password):
"""Create auth middleware for the app."""
@middleware
@asyncio.coroutine
def auth_middleware(request, handler):
async def auth_middleware(request, handler):
"""Authenticate as middleware."""
# If no password set, just always set authenticated=True
if api_password is None:
request[KEY_AUTHENTICATED] = True
return (yield from handler(request))
return await handler(request)
# Check authentication
authenticated = False
@ -50,10 +49,9 @@ def setup_auth(app, trusted_networks, api_password):
authenticated = True
request[KEY_AUTHENTICATED] = authenticated
return (yield from handler(request))
return await handler(request)
@asyncio.coroutine
def auth_startup(app):
async def auth_startup(app):
"""Initialize auth middleware when app starts up."""
app.middlewares.append(auth_middleware)

View File

@ -1,5 +1,5 @@
"""Ban logic for HTTP component."""
import asyncio
from collections import defaultdict
from datetime import datetime
from ipaddress import ip_address
@ -38,11 +38,10 @@ SCHEMA_IP_BAN_ENTRY = vol.Schema({
@callback
def setup_bans(hass, app, login_threshold):
"""Create IP Ban middleware for the app."""
@asyncio.coroutine
def ban_startup(app):
async def ban_startup(app):
"""Initialize bans when app starts up."""
app.middlewares.append(ban_middleware)
app[KEY_BANNED_IPS] = yield from hass.async_add_job(
app[KEY_BANNED_IPS] = await hass.async_add_job(
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
app[KEY_LOGIN_THRESHOLD] = login_threshold
@ -51,12 +50,11 @@ def setup_bans(hass, app, login_threshold):
@middleware
@asyncio.coroutine
def ban_middleware(request, handler):
async def ban_middleware(request, handler):
"""IP Ban middleware."""
if KEY_BANNED_IPS not in request.app:
_LOGGER.error('IP Ban middleware loaded but banned IPs not loaded')
return (yield from handler(request))
return await handler(request)
# Verify if IP is not banned
ip_address_ = request[KEY_REAL_IP]
@ -67,14 +65,13 @@ def ban_middleware(request, handler):
raise HTTPForbidden()
try:
return (yield from handler(request))
return await handler(request)
except HTTPUnauthorized:
yield from process_wrong_login(request)
await process_wrong_login(request)
raise
@asyncio.coroutine
def process_wrong_login(request):
async def process_wrong_login(request):
"""Process a wrong login attempt."""
remote_addr = request[KEY_REAL_IP]
@ -98,7 +95,7 @@ def process_wrong_login(request):
request.app[KEY_BANNED_IPS].append(new_ban)
hass = request.app['hass']
yield from hass.async_add_job(
await hass.async_add_job(
update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban)
_LOGGER.warning(

View File

@ -1,5 +1,5 @@
"""Provide cors support for the HTTP component."""
import asyncio
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
@ -27,8 +27,7 @@ def setup_cors(app, origins):
) for host in origins
})
@asyncio.coroutine
def cors_startup(app):
async def cors_startup(app):
"""Initialize cors when app starts up."""
cors_added = set()

View File

@ -1,5 +1,5 @@
"""Decorator for view methods to help with data validation."""
import asyncio
from functools import wraps
import logging
@ -24,16 +24,15 @@ class RequestDataValidator:
def __call__(self, method):
"""Decorate a function."""
@asyncio.coroutine
@wraps(method)
def wrapper(view, request, *args, **kwargs):
async def wrapper(view, request, *args, **kwargs):
"""Wrap a request handler with data validation."""
data = None
try:
data = yield from request.json()
data = await request.json()
except ValueError:
if not self._allow_empty or \
(yield from request.content.read()) != b'':
(await request.content.read()) != b'':
_LOGGER.error('Invalid JSON received.')
return view.json_message('Invalid JSON.', 400)
data = {}
@ -45,7 +44,7 @@ class RequestDataValidator:
return view.json_message(
'Message format incorrect: {}'.format(err), 400)
result = yield from method(view, request, *args, **kwargs)
result = await method(view, request, *args, **kwargs)
return result
return wrapper

View File

@ -1,5 +1,5 @@
"""Middleware to fetch real IP."""
import asyncio
from ipaddress import ip_address
from aiohttp.web import middleware
@ -14,8 +14,7 @@ from .const import KEY_REAL_IP
def setup_real_ip(app, use_x_forwarded_for):
"""Create IP Ban middleware for the app."""
@middleware
@asyncio.coroutine
def real_ip_middleware(request, handler):
async def real_ip_middleware(request, handler):
"""Real IP middleware."""
if (use_x_forwarded_for and
X_FORWARDED_FOR in request.headers):
@ -25,10 +24,9 @@ def setup_real_ip(app, use_x_forwarded_for):
request[KEY_REAL_IP] = \
ip_address(request.transport.get_extra_info('peername')[0])
return (yield from handler(request))
return await handler(request)
@asyncio.coroutine
def app_startup(app):
async def app_startup(app):
"""Initialize bans when app starts up."""
app.middlewares.append(real_ip_middleware)

View File

@ -1,5 +1,5 @@
"""Static file handling for HTTP component."""
import asyncio
import re
from aiohttp import hdrs
@ -14,8 +14,7 @@ _FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE)
class CachingStaticResource(StaticResource):
"""Static Resource handler that will add cache headers."""
@asyncio.coroutine
def _handle(self, request):
async def _handle(self, request):
filename = URL(request.match_info['filename']).path
try:
# PyLint is wrong about resolve not being a member.
@ -32,7 +31,7 @@ class CachingStaticResource(StaticResource):
raise HTTPNotFound() from error
if filepath.is_dir():
return (yield from super()._handle(request))
return await super()._handle(request)
elif filepath.is_file():
return CachingFileResponse(filepath, chunk_size=self._chunk_size)
else:
@ -49,26 +48,24 @@ class CachingFileResponse(FileResponse):
orig_sendfile = self._sendfile
@asyncio.coroutine
def sendfile(request, fobj, count):
async def sendfile(request, fobj, count):
"""Sendfile that includes a cache header."""
cache_time = 31 * 86400 # = 1 month
self.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format(
cache_time)
yield from orig_sendfile(request, fobj, count)
await orig_sendfile(request, fobj, count)
# Overwriting like this because __init__ can change implementation.
self._sendfile = sendfile
@middleware
@asyncio.coroutine
def staticresource_middleware(request, handler):
async def staticresource_middleware(request, handler):
"""Middleware to strip out fingerprint from fingerprinted assets."""
path = request.path
if not path.startswith('/static/') and not path.startswith('/frontend'):
return (yield from handler(request))
return await handler(request)
fingerprinted = _FINGERPRINT.match(request.match_info['filename'])
@ -76,4 +73,4 @@ def staticresource_middleware(request, handler):
request.match_info['filename'] = \
'{}.{}'.format(*fingerprinted.groups())
return (yield from handler(request))
return await handler(request)

View File

@ -0,0 +1,121 @@
"""
This module provides WSGI application to serve the Home Assistant API.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/http/
"""
import asyncio
import json
import logging
from aiohttp import web
from aiohttp.web_exceptions import HTTPUnauthorized
import homeassistant.remote as rem
from homeassistant.core import is_callback
from homeassistant.const import CONTENT_TYPE_JSON
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
_LOGGER = logging.getLogger(__name__)
class HomeAssistantView(object):
"""Base view for all views."""
url = None
extra_urls = []
requires_auth = True # Views inheriting from this class can override this
# pylint: disable=no-self-use
def json(self, result, status_code=200, headers=None):
"""Return a JSON response."""
msg = json.dumps(
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
response = web.Response(
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code,
headers=headers)
response.enable_compression()
return response
def json_message(self, message, status_code=200, message_code=None,
headers=None):
"""Return a JSON message response."""
data = {'message': message}
if message_code is not None:
data['code'] = message_code
return self.json(data, status_code, headers=headers)
# pylint: disable=no-self-use
async def file(self, request, fil):
"""Return a file."""
assert isinstance(fil, str), 'only string paths allowed'
return web.FileResponse(fil)
def register(self, router):
"""Register the view with a router."""
assert self.url is not None, 'No url set for view'
urls = [self.url] + self.extra_urls
for method in ('get', 'post', 'delete', 'put'):
handler = getattr(self, method, None)
if not handler:
continue
handler = request_handler_factory(self, handler)
for url in urls:
router.add_route(method, url, handler)
# aiohttp_cors does not work with class based views
# self.app.router.add_route('*', self.url, self, name=self.name)
# for url in self.extra_urls:
# self.app.router.add_route('*', url, self)
def request_handler_factory(view, handler):
"""Wrap the handler classes."""
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
"Handler should be a coroutine or a callback."
async def handle(request):
"""Handle incoming request."""
if not request.app['hass'].is_running:
return web.Response(status=503)
authenticated = request.get(KEY_AUTHENTICATED, False)
if view.requires_auth and not authenticated:
raise HTTPUnauthorized()
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, request.get(KEY_REAL_IP), authenticated)
result = handler(request, **request.match_info)
if asyncio.iscoroutine(result):
result = await result
if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it
return result
status_code = 200
if isinstance(result, tuple):
result, status_code = result
if isinstance(result, str):
result = result.encode('utf-8')
elif result is None:
result = b''
elif not isinstance(result, bytes):
assert False, ('Result should be None, string, bytes or Response. '
'Got: {}').format(result)
return web.Response(body=result, status=status_code)
return handle

View File

@ -191,8 +191,7 @@ def result_message(iden, result=None):
}
@asyncio.coroutine
def async_setup(hass, config):
async def async_setup(hass, config):
"""Initialize the websocket API."""
hass.http.register_view(WebsocketAPIView)
return True
@ -205,11 +204,10 @@ class WebsocketAPIView(HomeAssistantView):
url = URL
requires_auth = False
@asyncio.coroutine
def get(self, request):
async def get(self, request):
"""Handle an incoming websocket connection."""
# pylint: disable=no-self-use
return ActiveConnection(request.app['hass'], request).handle()
return await ActiveConnection(request.app['hass'], request).handle()
class ActiveConnection:
@ -233,17 +231,16 @@ class ActiveConnection:
"""Print an error message."""
_LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)
@asyncio.coroutine
def _writer(self):
async def _writer(self):
"""Write outgoing messages."""
# Exceptions if Socket disconnected or cancelled by connection handler
with suppress(RuntimeError, *CANCELLATION_ERRORS):
while not self.wsock.closed:
message = yield from self.to_write.get()
message = await self.to_write.get()
if message is None:
break
self.debug("Sending", message)
yield from self.wsock.send_json(message, dumps=JSON_DUMP)
await self.wsock.send_json(message, dumps=JSON_DUMP)
@callback
def send_message_outside(self, message):
@ -266,12 +263,11 @@ class ActiveConnection:
self._handle_task.cancel()
self._writer_task.cancel()
@asyncio.coroutine
def handle(self):
async def handle(self):
"""Handle the websocket connection."""
request = self.request
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
yield from wsock.prepare(request)
await wsock.prepare(request)
self.debug("Connected")
# Get a reference to current task so we can cancel our connection
@ -294,8 +290,8 @@ class ActiveConnection:
authenticated = True
else:
yield from self.wsock.send_json(auth_required_message())
msg = yield from wsock.receive_json()
await self.wsock.send_json(auth_required_message())
msg = await wsock.receive_json()
msg = AUTH_MESSAGE_SCHEMA(msg)
if validate_password(request, msg['api_password']):
@ -303,18 +299,18 @@ class ActiveConnection:
else:
self.debug("Invalid password")
yield from self.wsock.send_json(
await self.wsock.send_json(
auth_invalid_message('Invalid password'))
if not authenticated:
yield from process_wrong_login(request)
await process_wrong_login(request)
return wsock
yield from self.wsock.send_json(auth_ok_message())
await self.wsock.send_json(auth_ok_message())
# ---------- AUTH PHASE OVER ----------
msg = yield from wsock.receive_json()
msg = await wsock.receive_json()
last_id = 0
while msg:
@ -332,7 +328,7 @@ class ActiveConnection:
getattr(self, handler_name)(msg)
last_id = cur_id
msg = yield from wsock.receive_json()
msg = await wsock.receive_json()
except vol.Invalid as err:
error_msg = "Message incorrectly formatted: "
@ -394,11 +390,11 @@ class ActiveConnection:
self.to_write.put_nowait(final_message)
self.to_write.put_nowait(None)
# Make sure all error messages are written before closing
yield from self._writer_task
await self._writer_task
except asyncio.QueueFull:
self._writer_task.cancel()
yield from wsock.close()
await wsock.close()
self.debug("Closed connection")
return wsock
@ -410,8 +406,7 @@ class ActiveConnection:
"""
msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)
@asyncio.coroutine
def forward_events(event):
async def forward_events(event):
"""Forward events to websocket."""
if event.event_type == EVENT_TIME_CHANGED:
return
@ -447,10 +442,9 @@ class ActiveConnection:
"""
msg = CALL_SERVICE_MESSAGE_SCHEMA(msg)
@asyncio.coroutine
def call_service_helper(msg):
async def call_service_helper(msg):
"""Call a service and fire complete message."""
yield from self.hass.services.async_call(
await self.hass.services.async_call(
msg['domain'], msg['service'], msg.get('service_data'), True)
self.send_message_outside(result_message(msg['id']))
@ -473,10 +467,9 @@ class ActiveConnection:
"""
msg = GET_SERVICES_MESSAGE_SCHEMA(msg)
@asyncio.coroutine
def get_services_helper(msg):
async def get_services_helper(msg):
"""Get available services and fire complete message."""
descriptions = yield from async_get_all_descriptions(self.hass)
descriptions = await async_get_all_descriptions(self.hass)
self.send_message_outside(result_message(msg['id'], descriptions))
self.hass.async_add_job(get_services_helper(msg))

View File

@ -1,5 +1,4 @@
"""Tests for the HTTP component."""
import asyncio
from ipaddress import ip_address
from aiohttp import web
@ -18,18 +17,16 @@ def mock_real_ip(app):
nonlocal ip_to_mock
ip_to_mock = value
@asyncio.coroutine
@web.middleware
def mock_real_ip(request, handler):
async def mock_real_ip(request, handler):
"""Mock Real IP middleware."""
nonlocal ip_to_mock
request[KEY_REAL_IP] = ip_address(ip_to_mock)
return (yield from handler(request))
return (await handler(request))
@asyncio.coroutine
def real_ip_startup(app):
async def real_ip_startup(app):
"""Startup of real ip."""
app.middlewares.insert(0, mock_real_ip)

View File

@ -1,6 +1,5 @@
"""The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access
import asyncio
from ipaddress import ip_network
from unittest.mock import patch
@ -30,8 +29,7 @@ TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
@asyncio.coroutine
def mock_handler(request):
async def mock_handler(request):
"""Return if request was authenticated."""
if not request[KEY_AUTHENTICATED]:
raise HTTPUnauthorized
@ -47,84 +45,79 @@ def app():
return app
@asyncio.coroutine
def test_auth_middleware_loaded_by_default(hass):
async def test_auth_middleware_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_auth') as mock_setup:
yield from async_setup_component(hass, 'http', {
await async_setup_component(hass, 'http', {
'http': {}
})
assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine
def test_access_without_password(app, test_client):
async def test_access_without_password(app, test_client):
"""Test access without password."""
setup_auth(app, [], None)
client = yield from test_client(app)
client = await test_client(app)
resp = yield from client.get('/')
resp = await client.get('/')
assert resp.status == 200
@asyncio.coroutine
def test_access_with_password_in_header(app, test_client):
async def test_access_with_password_in_header(app, test_client):
"""Test access with password in URL."""
setup_auth(app, [], API_PASSWORD)
client = yield from test_client(app)
client = await test_client(app)
req = yield from client.get(
req = await client.get(
'/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status == 200
req = yield from client.get(
req = await client.get(
'/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'})
assert req.status == 401
@asyncio.coroutine
def test_access_with_password_in_query(app, test_client):
async def test_access_with_password_in_query(app, test_client):
"""Test access without password."""
setup_auth(app, [], API_PASSWORD)
client = yield from test_client(app)
client = await test_client(app)
resp = yield from client.get('/', params={
resp = await client.get('/', params={
'api_password': API_PASSWORD
})
assert resp.status == 200
resp = yield from client.get('/')
resp = await client.get('/')
assert resp.status == 401
resp = yield from client.get('/', params={
resp = await client.get('/', params={
'api_password': 'wrong-password'
})
assert resp.status == 401
@asyncio.coroutine
def test_basic_auth_works(app, test_client):
async def test_basic_auth_works(app, test_client):
"""Test access with basic authentication."""
setup_auth(app, [], API_PASSWORD)
client = yield from test_client(app)
client = await test_client(app)
req = yield from client.get(
req = await client.get(
'/',
auth=BasicAuth('homeassistant', API_PASSWORD))
assert req.status == 200
req = yield from client.get(
req = await client.get(
'/',
auth=BasicAuth('wrong_username', API_PASSWORD))
assert req.status == 401
req = yield from client.get(
req = await client.get(
'/',
auth=BasicAuth('homeassistant', 'wrong password'))
assert req.status == 401
req = yield from client.get(
req = await client.get(
'/',
headers={
'authorization': 'NotBasic abcdefg'
@ -132,8 +125,7 @@ def test_basic_auth_works(app, test_client):
assert req.status == 401
@asyncio.coroutine
def test_access_with_trusted_ip(test_client):
async def test_access_with_trusted_ip(test_client):
"""Test access with an untrusted ip address."""
app = web.Application()
app.router.add_get('/', mock_handler)
@ -141,16 +133,16 @@ def test_access_with_trusted_ip(test_client):
setup_auth(app, TRUSTED_NETWORKS, 'some-pass')
set_mock_ip = mock_real_ip(app)
client = yield from test_client(app)
client = await test_client(app)
for remote_addr in UNTRUSTED_ADDRESSES:
set_mock_ip(remote_addr)
resp = yield from client.get('/')
resp = await client.get('/')
assert resp.status == 401, \
"{} shouldn't be trusted".format(remote_addr)
for remote_addr in TRUSTED_ADDRESSES:
set_mock_ip(remote_addr)
resp = yield from client.get('/')
resp = await client.get('/')
assert resp.status == 200, \
"{} should be trusted".format(remote_addr)

View File

@ -1,6 +1,5 @@
"""The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access
import asyncio
from unittest.mock import patch, mock_open
from aiohttp import web
@ -16,8 +15,7 @@ from . import mock_real_ip
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
@asyncio.coroutine
def test_access_from_banned_ip(hass, test_client):
async def test_access_from_banned_ip(hass, test_client):
"""Test accessing to server from banned IP. Both trusted and not."""
app = web.Application()
setup_bans(hass, app, 5)
@ -26,19 +24,18 @@ def test_access_from_banned_ip(hass, test_client):
with patch('homeassistant.components.http.ban.load_ip_bans_config',
return_value=[IpBan(banned_ip) for banned_ip
in BANNED_IPS]):
client = yield from test_client(app)
client = await test_client(app)
for remote_addr in BANNED_IPS:
set_real_ip(remote_addr)
resp = yield from client.get('/')
resp = await client.get('/')
assert resp.status == 403
@asyncio.coroutine
def test_ban_middleware_not_loaded_by_config(hass):
async def test_ban_middleware_not_loaded_by_config(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_bans') as mock_setup:
yield from async_setup_component(hass, 'http', {
await async_setup_component(hass, 'http', {
'http': {
http.CONF_IP_BAN_ENABLED: False,
}
@ -47,25 +44,22 @@ def test_ban_middleware_not_loaded_by_config(hass):
assert len(mock_setup.mock_calls) == 0
@asyncio.coroutine
def test_ban_middleware_loaded_by_default(hass):
async def test_ban_middleware_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_bans') as mock_setup:
yield from async_setup_component(hass, 'http', {
await async_setup_component(hass, 'http', {
'http': {}
})
assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine
def test_ip_bans_file_creation(hass, test_client):
async def test_ip_bans_file_creation(hass, test_client):
"""Testing if banned IP file created."""
app = web.Application()
app['hass'] = hass
@asyncio.coroutine
def unauth_handler(request):
async def unauth_handler(request):
"""Return a mock web response."""
raise HTTPUnauthorized
@ -76,21 +70,21 @@ def test_ip_bans_file_creation(hass, test_client):
with patch('homeassistant.components.http.ban.load_ip_bans_config',
return_value=[IpBan(banned_ip) for banned_ip
in BANNED_IPS]):
client = yield from test_client(app)
client = await test_client(app)
m = mock_open()
with patch('homeassistant.components.http.ban.open', m, create=True):
resp = yield from client.get('/')
resp = await client.get('/')
assert resp.status == 401
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS)
assert m.call_count == 0
resp = yield from client.get('/')
resp = await client.get('/')
assert resp.status == 401
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
resp = yield from client.get('/')
resp = await client.get('/')
assert resp.status == 403
assert m.call_count == 1

View File

@ -1,5 +1,4 @@
"""Test cors for the HTTP component."""
import asyncio
from unittest.mock import patch
from aiohttp import web
@ -20,22 +19,20 @@ from homeassistant.components.http.cors import setup_cors
TRUSTED_ORIGIN = 'https://home-assistant.io'
@asyncio.coroutine
def test_cors_middleware_not_loaded_by_default(hass):
async def test_cors_middleware_not_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_cors') as mock_setup:
yield from async_setup_component(hass, 'http', {
await async_setup_component(hass, 'http', {
'http': {}
})
assert len(mock_setup.mock_calls) == 0
@asyncio.coroutine
def test_cors_middleware_loaded_from_config(hass):
async def test_cors_middleware_loaded_from_config(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_cors') as mock_setup:
yield from async_setup_component(hass, 'http', {
await async_setup_component(hass, 'http', {
'http': {
'cors_allowed_origins': ['http://home-assistant.io']
}
@ -44,8 +41,7 @@ def test_cors_middleware_loaded_from_config(hass):
assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine
def mock_handler(request):
async def mock_handler(request):
"""Return if request was authenticated."""
return web.Response(status=200)
@ -59,10 +55,9 @@ def client(loop, test_client):
return loop.run_until_complete(test_client(app))
@asyncio.coroutine
def test_cors_requests(client):
async def test_cors_requests(client):
"""Test cross origin requests."""
req = yield from client.get('/', headers={
req = await client.get('/', headers={
ORIGIN: TRUSTED_ORIGIN
})
assert req.status == 200
@ -70,7 +65,7 @@ def test_cors_requests(client):
TRUSTED_ORIGIN
# With password in URL
req = yield from client.get('/', params={
req = await client.get('/', params={
'api_password': 'some-pass'
}, headers={
ORIGIN: TRUSTED_ORIGIN
@ -80,7 +75,7 @@ def test_cors_requests(client):
TRUSTED_ORIGIN
# With password in headers
req = yield from client.get('/', headers={
req = await client.get('/', headers={
HTTP_HEADER_HA_AUTH: 'some-pass',
ORIGIN: TRUSTED_ORIGIN
})
@ -89,10 +84,9 @@ def test_cors_requests(client):
TRUSTED_ORIGIN
@asyncio.coroutine
def test_cors_preflight_allowed(client):
async def test_cors_preflight_allowed(client):
"""Test cross origin resource sharing preflight (OPTIONS) request."""
req = yield from client.options('/', headers={
req = await client.options('/', headers={
ORIGIN: TRUSTED_ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'

View File

@ -1,5 +1,4 @@
"""Test data validator decorator."""
import asyncio
from unittest.mock import Mock
from aiohttp import web
@ -9,8 +8,7 @@ from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator
@asyncio.coroutine
def get_client(test_client, validator):
async def get_client(test_client, validator):
"""Generate a client that hits a view decorated with validator."""
app = web.Application()
app['hass'] = Mock(is_running=True)
@ -20,58 +18,55 @@ def get_client(test_client, validator):
name = 'test'
requires_auth = False
@asyncio.coroutine
@validator
def post(self, request, data):
async def post(self, request, data):
"""Test method."""
return b''
TestView().register(app.router)
client = yield from test_client(app)
client = await test_client(app)
return client
@asyncio.coroutine
def test_validator(test_client):
async def test_validator(test_client):
"""Test the validator."""
client = yield from get_client(
client = await get_client(
test_client, RequestDataValidator(vol.Schema({
vol.Required('test'): str
})))
resp = yield from client.post('/', json={
resp = await client.post('/', json={
'test': 'bla'
})
assert resp.status == 200
resp = yield from client.post('/', json={
resp = await client.post('/', json={
'test': 100
})
assert resp.status == 400
resp = yield from client.post('/')
resp = await client.post('/')
assert resp.status == 400
@asyncio.coroutine
def test_validator_allow_empty(test_client):
async def test_validator_allow_empty(test_client):
"""Test the validator with empty data."""
client = yield from get_client(
client = await get_client(
test_client, RequestDataValidator(vol.Schema({
# Although we allow empty, our schema should still be able
# to validate an empty dict.
vol.Optional('test'): str
}), allow_empty=True))
resp = yield from client.post('/', json={
resp = await client.post('/', json={
'test': 'bla'
})
assert resp.status == 200
resp = yield from client.post('/', json={
resp = await client.post('/', json={
'test': 100
})
assert resp.status == 400
resp = yield from client.post('/')
resp = await client.post('/')
assert resp.status == 200

View File

@ -1,6 +1,4 @@
"""The tests for the Home Assistant HTTP component."""
import asyncio
from homeassistant.setup import async_setup_component
import homeassistant.components.http as http
@ -12,16 +10,14 @@ class TestView(http.HomeAssistantView):
name = 'test'
url = '/hello'
@asyncio.coroutine
def get(self, request):
async def get(self, request):
"""Return a get request."""
return 'hello'
@asyncio.coroutine
def test_registering_view_while_running(hass, test_client, unused_port):
async def test_registering_view_while_running(hass, test_client, unused_port):
"""Test that we can register a view while the server is running."""
yield from async_setup_component(
await async_setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_SERVER_PORT: unused_port(),
@ -29,15 +25,14 @@ def test_registering_view_while_running(hass, test_client, unused_port):
}
)
yield from hass.async_start()
await hass.async_start()
# This raises a RuntimeError if app is frozen
hass.http.register_view(TestView)
@asyncio.coroutine
def test_api_base_url_with_domain(hass):
async def test_api_base_url_with_domain(hass):
"""Test setting API URL."""
result = yield from async_setup_component(hass, 'http', {
result = await async_setup_component(hass, 'http', {
'http': {
'base_url': 'example.com'
}
@ -46,10 +41,9 @@ def test_api_base_url_with_domain(hass):
assert hass.config.api.base_url == 'http://example.com'
@asyncio.coroutine
def test_api_base_url_with_ip(hass):
async def test_api_base_url_with_ip(hass):
"""Test setting api url."""
result = yield from async_setup_component(hass, 'http', {
result = await async_setup_component(hass, 'http', {
'http': {
'server_host': '1.1.1.1'
}
@ -58,10 +52,9 @@ def test_api_base_url_with_ip(hass):
assert hass.config.api.base_url == 'http://1.1.1.1:8123'
@asyncio.coroutine
def test_api_base_url_with_ip_port(hass):
async def test_api_base_url_with_ip_port(hass):
"""Test setting api url."""
result = yield from async_setup_component(hass, 'http', {
result = await async_setup_component(hass, 'http', {
'http': {
'base_url': '1.1.1.1:8124'
}
@ -70,10 +63,9 @@ def test_api_base_url_with_ip_port(hass):
assert hass.config.api.base_url == 'http://1.1.1.1:8124'
@asyncio.coroutine
def test_api_no_base_url(hass):
async def test_api_no_base_url(hass):
"""Test setting api url."""
result = yield from async_setup_component(hass, 'http', {
result = await async_setup_component(hass, 'http', {
'http': {
}
})
@ -81,10 +73,9 @@ def test_api_no_base_url(hass):
assert hass.config.api.base_url == 'http://127.0.0.1:8123'
@asyncio.coroutine
def test_not_log_password(hass, unused_port, test_client, caplog):
async def test_not_log_password(hass, unused_port, test_client, caplog):
"""Test access with password doesn't get logged."""
result = yield from async_setup_component(hass, 'api', {
result = await async_setup_component(hass, 'api', {
'http': {
http.CONF_SERVER_PORT: unused_port(),
http.CONF_API_PASSWORD: 'some-pass'
@ -92,9 +83,9 @@ def test_not_log_password(hass, unused_port, test_client, caplog):
})
assert result
client = yield from test_client(hass.http.app)
client = await test_client(hass.http.app)
resp = yield from client.get('/api/', params={
resp = await client.get('/api/', params={
'api_password': 'some-pass'
})

View File

@ -1,6 +1,4 @@
"""Test real IP middleware."""
import asyncio
from aiohttp import web
from aiohttp.hdrs import X_FORWARDED_FOR
@ -8,41 +6,38 @@ from homeassistant.components.http.real_ip import setup_real_ip
from homeassistant.components.http.const import KEY_REAL_IP
@asyncio.coroutine
def mock_handler(request):
async def mock_handler(request):
"""Handler that returns the real IP as text."""
return web.Response(text=str(request[KEY_REAL_IP]))
@asyncio.coroutine
def test_ignore_x_forwarded_for(test_client):
async def test_ignore_x_forwarded_for(test_client):
"""Test that we get the IP from the transport."""
app = web.Application()
app.router.add_get('/', mock_handler)
setup_real_ip(app, False)
mock_api_client = yield from test_client(app)
mock_api_client = await test_client(app)
resp = yield from mock_api_client.get('/', headers={
resp = await mock_api_client.get('/', headers={
X_FORWARDED_FOR: '255.255.255.255'
})
assert resp.status == 200
text = yield from resp.text()
text = await resp.text()
assert text != '255.255.255.255'
@asyncio.coroutine
def test_use_x_forwarded_for(test_client):
async def test_use_x_forwarded_for(test_client):
"""Test that we get the IP from the transport."""
app = web.Application()
app.router.add_get('/', mock_handler)
setup_real_ip(app, True)
mock_api_client = yield from test_client(app)
mock_api_client = await test_client(app)
resp = yield from mock_api_client.get('/', headers={
resp = await mock_api_client.get('/', headers={
X_FORWARDED_FOR: '255.255.255.255'
})
assert resp.status == 200
text = yield from resp.text()
text = await resp.text()
assert text == '255.255.255.255'