Move HomeAssistantView to separate file. Convert http to async syntax. [skip ci] (#12982)
* Move HomeAssistantView to separate file. Convert http to async syntax. * pylint * websocket api * update emulated_hue for async/await * Lintpull/12996/head
parent
2ee73ca911
commit
321eb2ec6f
|
@ -4,7 +4,6 @@ Support for local control of entities by emulating the Phillips Hue bridge.
|
|||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/components/emulated_hue/
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -111,17 +110,15 @@ def setup(hass, yaml_config):
|
|||
config.upnp_bind_multicast, config.advertise_ip,
|
||||
config.advertise_port)
|
||||
|
||||
@asyncio.coroutine
|
||||
def stop_emulated_hue_bridge(event):
|
||||
async def stop_emulated_hue_bridge(event):
|
||||
"""Stop the emulated hue bridge."""
|
||||
upnp_listener.stop()
|
||||
yield from server.stop()
|
||||
await server.stop()
|
||||
|
||||
@asyncio.coroutine
|
||||
def start_emulated_hue_bridge(event):
|
||||
async def start_emulated_hue_bridge(event):
|
||||
"""Start the emulated hue bridge."""
|
||||
upnp_listener.start()
|
||||
yield from server.start()
|
||||
await server.start()
|
||||
hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_STOP, stop_emulated_hue_bridge)
|
||||
|
||||
|
|
|
@ -4,21 +4,18 @@ This module provides WSGI application to serve the Home Assistant API.
|
|||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/components/http/
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from ipaddress import ip_network
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently
|
||||
from aiohttp.web_exceptions import HTTPMovedPermanently
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import (
|
||||
SERVER_PORT, CONTENT_TYPE_JSON,
|
||||
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,)
|
||||
from homeassistant.core import is_callback
|
||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, SERVER_PORT)
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
import homeassistant.remote as rem
|
||||
import homeassistant.util as hass_util
|
||||
|
@ -28,10 +25,13 @@ from .auth import setup_auth
|
|||
from .ban import setup_bans
|
||||
from .cors import setup_cors
|
||||
from .real_ip import setup_real_ip
|
||||
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
||||
from .static import (
|
||||
CachingFileResponse, CachingStaticResource, staticresource_middleware)
|
||||
|
||||
# Import as alias
|
||||
from .const import KEY_AUTHENTICATED, KEY_REAL_IP # noqa
|
||||
from .view import HomeAssistantView # noqa
|
||||
|
||||
REQUIREMENTS = ['aiohttp_cors==0.6.0']
|
||||
|
||||
DOMAIN = 'http'
|
||||
|
@ -98,8 +98,7 @@ CONFIG_SCHEMA = vol.Schema({
|
|||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup(hass, config):
|
||||
async def async_setup(hass, config):
|
||||
"""Set up the HTTP API and debug interface."""
|
||||
conf = config.get(DOMAIN)
|
||||
|
||||
|
@ -135,16 +134,14 @@ def async_setup(hass, config):
|
|||
is_ban_enabled=is_ban_enabled
|
||||
)
|
||||
|
||||
@asyncio.coroutine
|
||||
def stop_server(event):
|
||||
async def stop_server(event):
|
||||
"""Stop the server."""
|
||||
yield from server.stop()
|
||||
await server.stop()
|
||||
|
||||
@asyncio.coroutine
|
||||
def start_server(event):
|
||||
async def start_server(event):
|
||||
"""Start the server."""
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
|
||||
yield from server.start()
|
||||
await server.start()
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_server)
|
||||
|
||||
|
@ -252,13 +249,11 @@ class HomeAssistantHTTP(object):
|
|||
return
|
||||
|
||||
if cache_headers:
|
||||
@asyncio.coroutine
|
||||
def serve_file(request):
|
||||
async def serve_file(request):
|
||||
"""Serve file from disk."""
|
||||
return CachingFileResponse(path)
|
||||
else:
|
||||
@asyncio.coroutine
|
||||
def serve_file(request):
|
||||
async def serve_file(request):
|
||||
"""Serve file from disk."""
|
||||
return web.FileResponse(path)
|
||||
|
||||
|
@ -276,14 +271,13 @@ class HomeAssistantHTTP(object):
|
|||
|
||||
self.app.router.add_route('GET', url_pattern, serve_file)
|
||||
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
async def start(self):
|
||||
"""Start the WSGI server."""
|
||||
# We misunderstood the startup signal. You're not allowed to change
|
||||
# anything during startup. Temp workaround.
|
||||
# pylint: disable=protected-access
|
||||
self.app._on_startup.freeze()
|
||||
yield from self.app.startup()
|
||||
await self.app.startup()
|
||||
|
||||
if self.ssl_certificate:
|
||||
try:
|
||||
|
@ -308,121 +302,18 @@ class HomeAssistantHTTP(object):
|
|||
self._handler = self.app.make_handler(loop=self.hass.loop)
|
||||
|
||||
try:
|
||||
self.server = yield from self.hass.loop.create_server(
|
||||
self.server = await self.hass.loop.create_server(
|
||||
self._handler, self.server_host, self.server_port, ssl=context)
|
||||
except OSError as error:
|
||||
_LOGGER.error("Failed to create HTTP server at port %d: %s",
|
||||
self.server_port, error)
|
||||
|
||||
@asyncio.coroutine
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
"""Stop the WSGI server."""
|
||||
if self.server:
|
||||
self.server.close()
|
||||
yield from self.server.wait_closed()
|
||||
yield from self.app.shutdown()
|
||||
await self.server.wait_closed()
|
||||
await self.app.shutdown()
|
||||
if self._handler:
|
||||
yield from self._handler.shutdown(10)
|
||||
yield from self.app.cleanup()
|
||||
|
||||
|
||||
class HomeAssistantView(object):
|
||||
"""Base view for all views."""
|
||||
|
||||
url = None
|
||||
extra_urls = []
|
||||
requires_auth = True # Views inheriting from this class can override this
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def json(self, result, status_code=200, headers=None):
|
||||
"""Return a JSON response."""
|
||||
msg = json.dumps(
|
||||
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
|
||||
response = web.Response(
|
||||
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code,
|
||||
headers=headers)
|
||||
response.enable_compression()
|
||||
return response
|
||||
|
||||
def json_message(self, message, status_code=200, message_code=None,
|
||||
headers=None):
|
||||
"""Return a JSON message response."""
|
||||
data = {'message': message}
|
||||
if message_code is not None:
|
||||
data['code'] = message_code
|
||||
return self.json(data, status_code, headers=headers)
|
||||
|
||||
@asyncio.coroutine
|
||||
# pylint: disable=no-self-use
|
||||
def file(self, request, fil):
|
||||
"""Return a file."""
|
||||
assert isinstance(fil, str), 'only string paths allowed'
|
||||
return web.FileResponse(fil)
|
||||
|
||||
def register(self, router):
|
||||
"""Register the view with a router."""
|
||||
assert self.url is not None, 'No url set for view'
|
||||
urls = [self.url] + self.extra_urls
|
||||
|
||||
for method in ('get', 'post', 'delete', 'put'):
|
||||
handler = getattr(self, method, None)
|
||||
|
||||
if not handler:
|
||||
continue
|
||||
|
||||
handler = request_handler_factory(self, handler)
|
||||
|
||||
for url in urls:
|
||||
router.add_route(method, url, handler)
|
||||
|
||||
# aiohttp_cors does not work with class based views
|
||||
# self.app.router.add_route('*', self.url, self, name=self.name)
|
||||
|
||||
# for url in self.extra_urls:
|
||||
# self.app.router.add_route('*', url, self)
|
||||
|
||||
|
||||
def request_handler_factory(view, handler):
|
||||
"""Wrap the handler classes."""
|
||||
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
|
||||
"Handler should be a coroutine or a callback."
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle(request):
|
||||
"""Handle incoming request."""
|
||||
if not request.app['hass'].is_running:
|
||||
return web.Response(status=503)
|
||||
|
||||
authenticated = request.get(KEY_AUTHENTICATED, False)
|
||||
|
||||
if view.requires_auth and not authenticated:
|
||||
raise HTTPUnauthorized()
|
||||
|
||||
_LOGGER.info('Serving %s to %s (auth: %s)',
|
||||
request.path, request.get(KEY_REAL_IP), authenticated)
|
||||
|
||||
result = handler(request, **request.match_info)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
result = yield from result
|
||||
|
||||
if isinstance(result, web.StreamResponse):
|
||||
# The method handler returned a ready-made Response, how nice of it
|
||||
return result
|
||||
|
||||
status_code = 200
|
||||
|
||||
if isinstance(result, tuple):
|
||||
result, status_code = result
|
||||
|
||||
if isinstance(result, str):
|
||||
result = result.encode('utf-8')
|
||||
elif result is None:
|
||||
result = b''
|
||||
elif not isinstance(result, bytes):
|
||||
assert False, ('Result should be None, string, bytes or Response. '
|
||||
'Got: {}').format(result)
|
||||
|
||||
return web.Response(body=result, status=status_code)
|
||||
|
||||
return handle
|
||||
await self._handler.shutdown(10)
|
||||
await self.app.cleanup()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""Authentication for HTTP component."""
|
||||
import asyncio
|
||||
|
||||
import base64
|
||||
import hmac
|
||||
import logging
|
||||
|
@ -20,13 +20,12 @@ _LOGGER = logging.getLogger(__name__)
|
|||
def setup_auth(app, trusted_networks, api_password):
|
||||
"""Create auth middleware for the app."""
|
||||
@middleware
|
||||
@asyncio.coroutine
|
||||
def auth_middleware(request, handler):
|
||||
async def auth_middleware(request, handler):
|
||||
"""Authenticate as middleware."""
|
||||
# If no password set, just always set authenticated=True
|
||||
if api_password is None:
|
||||
request[KEY_AUTHENTICATED] = True
|
||||
return (yield from handler(request))
|
||||
return await handler(request)
|
||||
|
||||
# Check authentication
|
||||
authenticated = False
|
||||
|
@ -50,10 +49,9 @@ def setup_auth(app, trusted_networks, api_password):
|
|||
authenticated = True
|
||||
|
||||
request[KEY_AUTHENTICATED] = authenticated
|
||||
return (yield from handler(request))
|
||||
return await handler(request)
|
||||
|
||||
@asyncio.coroutine
|
||||
def auth_startup(app):
|
||||
async def auth_startup(app):
|
||||
"""Initialize auth middleware when app starts up."""
|
||||
app.middlewares.append(auth_middleware)
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""Ban logic for HTTP component."""
|
||||
import asyncio
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from ipaddress import ip_address
|
||||
|
@ -38,11 +38,10 @@ SCHEMA_IP_BAN_ENTRY = vol.Schema({
|
|||
@callback
|
||||
def setup_bans(hass, app, login_threshold):
|
||||
"""Create IP Ban middleware for the app."""
|
||||
@asyncio.coroutine
|
||||
def ban_startup(app):
|
||||
async def ban_startup(app):
|
||||
"""Initialize bans when app starts up."""
|
||||
app.middlewares.append(ban_middleware)
|
||||
app[KEY_BANNED_IPS] = yield from hass.async_add_job(
|
||||
app[KEY_BANNED_IPS] = await hass.async_add_job(
|
||||
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
|
||||
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
||||
app[KEY_LOGIN_THRESHOLD] = login_threshold
|
||||
|
@ -51,12 +50,11 @@ def setup_bans(hass, app, login_threshold):
|
|||
|
||||
|
||||
@middleware
|
||||
@asyncio.coroutine
|
||||
def ban_middleware(request, handler):
|
||||
async def ban_middleware(request, handler):
|
||||
"""IP Ban middleware."""
|
||||
if KEY_BANNED_IPS not in request.app:
|
||||
_LOGGER.error('IP Ban middleware loaded but banned IPs not loaded')
|
||||
return (yield from handler(request))
|
||||
return await handler(request)
|
||||
|
||||
# Verify if IP is not banned
|
||||
ip_address_ = request[KEY_REAL_IP]
|
||||
|
@ -67,14 +65,13 @@ def ban_middleware(request, handler):
|
|||
raise HTTPForbidden()
|
||||
|
||||
try:
|
||||
return (yield from handler(request))
|
||||
return await handler(request)
|
||||
except HTTPUnauthorized:
|
||||
yield from process_wrong_login(request)
|
||||
await process_wrong_login(request)
|
||||
raise
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def process_wrong_login(request):
|
||||
async def process_wrong_login(request):
|
||||
"""Process a wrong login attempt."""
|
||||
remote_addr = request[KEY_REAL_IP]
|
||||
|
||||
|
@ -98,7 +95,7 @@ def process_wrong_login(request):
|
|||
request.app[KEY_BANNED_IPS].append(new_ban)
|
||||
|
||||
hass = request.app['hass']
|
||||
yield from hass.async_add_job(
|
||||
await hass.async_add_job(
|
||||
update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban)
|
||||
|
||||
_LOGGER.warning(
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""Provide cors support for the HTTP component."""
|
||||
import asyncio
|
||||
|
||||
|
||||
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
|
||||
|
||||
|
@ -27,8 +27,7 @@ def setup_cors(app, origins):
|
|||
) for host in origins
|
||||
})
|
||||
|
||||
@asyncio.coroutine
|
||||
def cors_startup(app):
|
||||
async def cors_startup(app):
|
||||
"""Initialize cors when app starts up."""
|
||||
cors_added = set()
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""Decorator for view methods to help with data validation."""
|
||||
import asyncio
|
||||
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
|
@ -24,16 +24,15 @@ class RequestDataValidator:
|
|||
|
||||
def __call__(self, method):
|
||||
"""Decorate a function."""
|
||||
@asyncio.coroutine
|
||||
@wraps(method)
|
||||
def wrapper(view, request, *args, **kwargs):
|
||||
async def wrapper(view, request, *args, **kwargs):
|
||||
"""Wrap a request handler with data validation."""
|
||||
data = None
|
||||
try:
|
||||
data = yield from request.json()
|
||||
data = await request.json()
|
||||
except ValueError:
|
||||
if not self._allow_empty or \
|
||||
(yield from request.content.read()) != b'':
|
||||
(await request.content.read()) != b'':
|
||||
_LOGGER.error('Invalid JSON received.')
|
||||
return view.json_message('Invalid JSON.', 400)
|
||||
data = {}
|
||||
|
@ -45,7 +44,7 @@ class RequestDataValidator:
|
|||
return view.json_message(
|
||||
'Message format incorrect: {}'.format(err), 400)
|
||||
|
||||
result = yield from method(view, request, *args, **kwargs)
|
||||
result = await method(view, request, *args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""Middleware to fetch real IP."""
|
||||
import asyncio
|
||||
|
||||
from ipaddress import ip_address
|
||||
|
||||
from aiohttp.web import middleware
|
||||
|
@ -14,8 +14,7 @@ from .const import KEY_REAL_IP
|
|||
def setup_real_ip(app, use_x_forwarded_for):
|
||||
"""Create IP Ban middleware for the app."""
|
||||
@middleware
|
||||
@asyncio.coroutine
|
||||
def real_ip_middleware(request, handler):
|
||||
async def real_ip_middleware(request, handler):
|
||||
"""Real IP middleware."""
|
||||
if (use_x_forwarded_for and
|
||||
X_FORWARDED_FOR in request.headers):
|
||||
|
@ -25,10 +24,9 @@ def setup_real_ip(app, use_x_forwarded_for):
|
|||
request[KEY_REAL_IP] = \
|
||||
ip_address(request.transport.get_extra_info('peername')[0])
|
||||
|
||||
return (yield from handler(request))
|
||||
return await handler(request)
|
||||
|
||||
@asyncio.coroutine
|
||||
def app_startup(app):
|
||||
async def app_startup(app):
|
||||
"""Initialize bans when app starts up."""
|
||||
app.middlewares.append(real_ip_middleware)
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""Static file handling for HTTP component."""
|
||||
import asyncio
|
||||
|
||||
import re
|
||||
|
||||
from aiohttp import hdrs
|
||||
|
@ -14,8 +14,7 @@ _FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE)
|
|||
class CachingStaticResource(StaticResource):
|
||||
"""Static Resource handler that will add cache headers."""
|
||||
|
||||
@asyncio.coroutine
|
||||
def _handle(self, request):
|
||||
async def _handle(self, request):
|
||||
filename = URL(request.match_info['filename']).path
|
||||
try:
|
||||
# PyLint is wrong about resolve not being a member.
|
||||
|
@ -32,7 +31,7 @@ class CachingStaticResource(StaticResource):
|
|||
raise HTTPNotFound() from error
|
||||
|
||||
if filepath.is_dir():
|
||||
return (yield from super()._handle(request))
|
||||
return await super()._handle(request)
|
||||
elif filepath.is_file():
|
||||
return CachingFileResponse(filepath, chunk_size=self._chunk_size)
|
||||
else:
|
||||
|
@ -49,26 +48,24 @@ class CachingFileResponse(FileResponse):
|
|||
|
||||
orig_sendfile = self._sendfile
|
||||
|
||||
@asyncio.coroutine
|
||||
def sendfile(request, fobj, count):
|
||||
async def sendfile(request, fobj, count):
|
||||
"""Sendfile that includes a cache header."""
|
||||
cache_time = 31 * 86400 # = 1 month
|
||||
self.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format(
|
||||
cache_time)
|
||||
|
||||
yield from orig_sendfile(request, fobj, count)
|
||||
await orig_sendfile(request, fobj, count)
|
||||
|
||||
# Overwriting like this because __init__ can change implementation.
|
||||
self._sendfile = sendfile
|
||||
|
||||
|
||||
@middleware
|
||||
@asyncio.coroutine
|
||||
def staticresource_middleware(request, handler):
|
||||
async def staticresource_middleware(request, handler):
|
||||
"""Middleware to strip out fingerprint from fingerprinted assets."""
|
||||
path = request.path
|
||||
if not path.startswith('/static/') and not path.startswith('/frontend'):
|
||||
return (yield from handler(request))
|
||||
return await handler(request)
|
||||
|
||||
fingerprinted = _FINGERPRINT.match(request.match_info['filename'])
|
||||
|
||||
|
@ -76,4 +73,4 @@ def staticresource_middleware(request, handler):
|
|||
request.match_info['filename'] = \
|
||||
'{}.{}'.format(*fingerprinted.groups())
|
||||
|
||||
return (yield from handler(request))
|
||||
return await handler(request)
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
"""
|
||||
This module provides WSGI application to serve the Home Assistant API.
|
||||
|
||||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/components/http/
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||
|
||||
import homeassistant.remote as rem
|
||||
from homeassistant.core import is_callback
|
||||
from homeassistant.const import CONTENT_TYPE_JSON
|
||||
|
||||
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HomeAssistantView(object):
|
||||
"""Base view for all views."""
|
||||
|
||||
url = None
|
||||
extra_urls = []
|
||||
requires_auth = True # Views inheriting from this class can override this
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def json(self, result, status_code=200, headers=None):
|
||||
"""Return a JSON response."""
|
||||
msg = json.dumps(
|
||||
result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8')
|
||||
response = web.Response(
|
||||
body=msg, content_type=CONTENT_TYPE_JSON, status=status_code,
|
||||
headers=headers)
|
||||
response.enable_compression()
|
||||
return response
|
||||
|
||||
def json_message(self, message, status_code=200, message_code=None,
|
||||
headers=None):
|
||||
"""Return a JSON message response."""
|
||||
data = {'message': message}
|
||||
if message_code is not None:
|
||||
data['code'] = message_code
|
||||
return self.json(data, status_code, headers=headers)
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
async def file(self, request, fil):
|
||||
"""Return a file."""
|
||||
assert isinstance(fil, str), 'only string paths allowed'
|
||||
return web.FileResponse(fil)
|
||||
|
||||
def register(self, router):
|
||||
"""Register the view with a router."""
|
||||
assert self.url is not None, 'No url set for view'
|
||||
urls = [self.url] + self.extra_urls
|
||||
|
||||
for method in ('get', 'post', 'delete', 'put'):
|
||||
handler = getattr(self, method, None)
|
||||
|
||||
if not handler:
|
||||
continue
|
||||
|
||||
handler = request_handler_factory(self, handler)
|
||||
|
||||
for url in urls:
|
||||
router.add_route(method, url, handler)
|
||||
|
||||
# aiohttp_cors does not work with class based views
|
||||
# self.app.router.add_route('*', self.url, self, name=self.name)
|
||||
|
||||
# for url in self.extra_urls:
|
||||
# self.app.router.add_route('*', url, self)
|
||||
|
||||
|
||||
def request_handler_factory(view, handler):
|
||||
"""Wrap the handler classes."""
|
||||
assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \
|
||||
"Handler should be a coroutine or a callback."
|
||||
|
||||
async def handle(request):
|
||||
"""Handle incoming request."""
|
||||
if not request.app['hass'].is_running:
|
||||
return web.Response(status=503)
|
||||
|
||||
authenticated = request.get(KEY_AUTHENTICATED, False)
|
||||
|
||||
if view.requires_auth and not authenticated:
|
||||
raise HTTPUnauthorized()
|
||||
|
||||
_LOGGER.info('Serving %s to %s (auth: %s)',
|
||||
request.path, request.get(KEY_REAL_IP), authenticated)
|
||||
|
||||
result = handler(request, **request.match_info)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
|
||||
if isinstance(result, web.StreamResponse):
|
||||
# The method handler returned a ready-made Response, how nice of it
|
||||
return result
|
||||
|
||||
status_code = 200
|
||||
|
||||
if isinstance(result, tuple):
|
||||
result, status_code = result
|
||||
|
||||
if isinstance(result, str):
|
||||
result = result.encode('utf-8')
|
||||
elif result is None:
|
||||
result = b''
|
||||
elif not isinstance(result, bytes):
|
||||
assert False, ('Result should be None, string, bytes or Response. '
|
||||
'Got: {}').format(result)
|
||||
|
||||
return web.Response(body=result, status=status_code)
|
||||
|
||||
return handle
|
|
@ -191,8 +191,7 @@ def result_message(iden, result=None):
|
|||
}
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup(hass, config):
|
||||
async def async_setup(hass, config):
|
||||
"""Initialize the websocket API."""
|
||||
hass.http.register_view(WebsocketAPIView)
|
||||
return True
|
||||
|
@ -205,11 +204,10 @@ class WebsocketAPIView(HomeAssistantView):
|
|||
url = URL
|
||||
requires_auth = False
|
||||
|
||||
@asyncio.coroutine
|
||||
def get(self, request):
|
||||
async def get(self, request):
|
||||
"""Handle an incoming websocket connection."""
|
||||
# pylint: disable=no-self-use
|
||||
return ActiveConnection(request.app['hass'], request).handle()
|
||||
return await ActiveConnection(request.app['hass'], request).handle()
|
||||
|
||||
|
||||
class ActiveConnection:
|
||||
|
@ -233,17 +231,16 @@ class ActiveConnection:
|
|||
"""Print an error message."""
|
||||
_LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _writer(self):
|
||||
async def _writer(self):
|
||||
"""Write outgoing messages."""
|
||||
# Exceptions if Socket disconnected or cancelled by connection handler
|
||||
with suppress(RuntimeError, *CANCELLATION_ERRORS):
|
||||
while not self.wsock.closed:
|
||||
message = yield from self.to_write.get()
|
||||
message = await self.to_write.get()
|
||||
if message is None:
|
||||
break
|
||||
self.debug("Sending", message)
|
||||
yield from self.wsock.send_json(message, dumps=JSON_DUMP)
|
||||
await self.wsock.send_json(message, dumps=JSON_DUMP)
|
||||
|
||||
@callback
|
||||
def send_message_outside(self, message):
|
||||
|
@ -266,12 +263,11 @@ class ActiveConnection:
|
|||
self._handle_task.cancel()
|
||||
self._writer_task.cancel()
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle(self):
|
||||
async def handle(self):
|
||||
"""Handle the websocket connection."""
|
||||
request = self.request
|
||||
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
|
||||
yield from wsock.prepare(request)
|
||||
await wsock.prepare(request)
|
||||
self.debug("Connected")
|
||||
|
||||
# Get a reference to current task so we can cancel our connection
|
||||
|
@ -294,8 +290,8 @@ class ActiveConnection:
|
|||
authenticated = True
|
||||
|
||||
else:
|
||||
yield from self.wsock.send_json(auth_required_message())
|
||||
msg = yield from wsock.receive_json()
|
||||
await self.wsock.send_json(auth_required_message())
|
||||
msg = await wsock.receive_json()
|
||||
msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||
|
||||
if validate_password(request, msg['api_password']):
|
||||
|
@ -303,18 +299,18 @@ class ActiveConnection:
|
|||
|
||||
else:
|
||||
self.debug("Invalid password")
|
||||
yield from self.wsock.send_json(
|
||||
await self.wsock.send_json(
|
||||
auth_invalid_message('Invalid password'))
|
||||
|
||||
if not authenticated:
|
||||
yield from process_wrong_login(request)
|
||||
await process_wrong_login(request)
|
||||
return wsock
|
||||
|
||||
yield from self.wsock.send_json(auth_ok_message())
|
||||
await self.wsock.send_json(auth_ok_message())
|
||||
|
||||
# ---------- AUTH PHASE OVER ----------
|
||||
|
||||
msg = yield from wsock.receive_json()
|
||||
msg = await wsock.receive_json()
|
||||
last_id = 0
|
||||
|
||||
while msg:
|
||||
|
@ -332,7 +328,7 @@ class ActiveConnection:
|
|||
getattr(self, handler_name)(msg)
|
||||
|
||||
last_id = cur_id
|
||||
msg = yield from wsock.receive_json()
|
||||
msg = await wsock.receive_json()
|
||||
|
||||
except vol.Invalid as err:
|
||||
error_msg = "Message incorrectly formatted: "
|
||||
|
@ -394,11 +390,11 @@ class ActiveConnection:
|
|||
self.to_write.put_nowait(final_message)
|
||||
self.to_write.put_nowait(None)
|
||||
# Make sure all error messages are written before closing
|
||||
yield from self._writer_task
|
||||
await self._writer_task
|
||||
except asyncio.QueueFull:
|
||||
self._writer_task.cancel()
|
||||
|
||||
yield from wsock.close()
|
||||
await wsock.close()
|
||||
self.debug("Closed connection")
|
||||
|
||||
return wsock
|
||||
|
@ -410,8 +406,7 @@ class ActiveConnection:
|
|||
"""
|
||||
msg = SUBSCRIBE_EVENTS_MESSAGE_SCHEMA(msg)
|
||||
|
||||
@asyncio.coroutine
|
||||
def forward_events(event):
|
||||
async def forward_events(event):
|
||||
"""Forward events to websocket."""
|
||||
if event.event_type == EVENT_TIME_CHANGED:
|
||||
return
|
||||
|
@ -447,10 +442,9 @@ class ActiveConnection:
|
|||
"""
|
||||
msg = CALL_SERVICE_MESSAGE_SCHEMA(msg)
|
||||
|
||||
@asyncio.coroutine
|
||||
def call_service_helper(msg):
|
||||
async def call_service_helper(msg):
|
||||
"""Call a service and fire complete message."""
|
||||
yield from self.hass.services.async_call(
|
||||
await self.hass.services.async_call(
|
||||
msg['domain'], msg['service'], msg.get('service_data'), True)
|
||||
self.send_message_outside(result_message(msg['id']))
|
||||
|
||||
|
@ -473,10 +467,9 @@ class ActiveConnection:
|
|||
"""
|
||||
msg = GET_SERVICES_MESSAGE_SCHEMA(msg)
|
||||
|
||||
@asyncio.coroutine
|
||||
def get_services_helper(msg):
|
||||
async def get_services_helper(msg):
|
||||
"""Get available services and fire complete message."""
|
||||
descriptions = yield from async_get_all_descriptions(self.hass)
|
||||
descriptions = await async_get_all_descriptions(self.hass)
|
||||
self.send_message_outside(result_message(msg['id'], descriptions))
|
||||
|
||||
self.hass.async_add_job(get_services_helper(msg))
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
"""Tests for the HTTP component."""
|
||||
import asyncio
|
||||
from ipaddress import ip_address
|
||||
|
||||
from aiohttp import web
|
||||
|
@ -18,18 +17,16 @@ def mock_real_ip(app):
|
|||
nonlocal ip_to_mock
|
||||
ip_to_mock = value
|
||||
|
||||
@asyncio.coroutine
|
||||
@web.middleware
|
||||
def mock_real_ip(request, handler):
|
||||
async def mock_real_ip(request, handler):
|
||||
"""Mock Real IP middleware."""
|
||||
nonlocal ip_to_mock
|
||||
|
||||
request[KEY_REAL_IP] = ip_address(ip_to_mock)
|
||||
|
||||
return (yield from handler(request))
|
||||
return (await handler(request))
|
||||
|
||||
@asyncio.coroutine
|
||||
def real_ip_startup(app):
|
||||
async def real_ip_startup(app):
|
||||
"""Startup of real ip."""
|
||||
app.middlewares.insert(0, mock_real_ip)
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
"""The tests for the Home Assistant HTTP component."""
|
||||
# pylint: disable=protected-access
|
||||
import asyncio
|
||||
from ipaddress import ip_network
|
||||
from unittest.mock import patch
|
||||
|
||||
|
@ -30,8 +29,7 @@ TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
|
|||
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def mock_handler(request):
|
||||
async def mock_handler(request):
|
||||
"""Return if request was authenticated."""
|
||||
if not request[KEY_AUTHENTICATED]:
|
||||
raise HTTPUnauthorized
|
||||
|
@ -47,84 +45,79 @@ def app():
|
|||
return app
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_auth_middleware_loaded_by_default(hass):
|
||||
async def test_auth_middleware_loaded_by_default(hass):
|
||||
"""Test accessing to server from banned IP when feature is off."""
|
||||
with patch('homeassistant.components.http.setup_auth') as mock_setup:
|
||||
yield from async_setup_component(hass, 'http', {
|
||||
await async_setup_component(hass, 'http', {
|
||||
'http': {}
|
||||
})
|
||||
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_without_password(app, test_client):
|
||||
async def test_access_without_password(app, test_client):
|
||||
"""Test access without password."""
|
||||
setup_auth(app, [], None)
|
||||
client = yield from test_client(app)
|
||||
client = await test_client(app)
|
||||
|
||||
resp = yield from client.get('/')
|
||||
resp = await client.get('/')
|
||||
assert resp.status == 200
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_with_password_in_header(app, test_client):
|
||||
async def test_access_with_password_in_header(app, test_client):
|
||||
"""Test access with password in URL."""
|
||||
setup_auth(app, [], API_PASSWORD)
|
||||
client = yield from test_client(app)
|
||||
client = await test_client(app)
|
||||
|
||||
req = yield from client.get(
|
||||
req = await client.get(
|
||||
'/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||
assert req.status == 200
|
||||
|
||||
req = yield from client.get(
|
||||
req = await client.get(
|
||||
'/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'})
|
||||
assert req.status == 401
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_with_password_in_query(app, test_client):
|
||||
async def test_access_with_password_in_query(app, test_client):
|
||||
"""Test access without password."""
|
||||
setup_auth(app, [], API_PASSWORD)
|
||||
client = yield from test_client(app)
|
||||
client = await test_client(app)
|
||||
|
||||
resp = yield from client.get('/', params={
|
||||
resp = await client.get('/', params={
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
assert resp.status == 200
|
||||
|
||||
resp = yield from client.get('/')
|
||||
resp = await client.get('/')
|
||||
assert resp.status == 401
|
||||
|
||||
resp = yield from client.get('/', params={
|
||||
resp = await client.get('/', params={
|
||||
'api_password': 'wrong-password'
|
||||
})
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_basic_auth_works(app, test_client):
|
||||
async def test_basic_auth_works(app, test_client):
|
||||
"""Test access with basic authentication."""
|
||||
setup_auth(app, [], API_PASSWORD)
|
||||
client = yield from test_client(app)
|
||||
client = await test_client(app)
|
||||
|
||||
req = yield from client.get(
|
||||
req = await client.get(
|
||||
'/',
|
||||
auth=BasicAuth('homeassistant', API_PASSWORD))
|
||||
assert req.status == 200
|
||||
|
||||
req = yield from client.get(
|
||||
req = await client.get(
|
||||
'/',
|
||||
auth=BasicAuth('wrong_username', API_PASSWORD))
|
||||
assert req.status == 401
|
||||
|
||||
req = yield from client.get(
|
||||
req = await client.get(
|
||||
'/',
|
||||
auth=BasicAuth('homeassistant', 'wrong password'))
|
||||
assert req.status == 401
|
||||
|
||||
req = yield from client.get(
|
||||
req = await client.get(
|
||||
'/',
|
||||
headers={
|
||||
'authorization': 'NotBasic abcdefg'
|
||||
|
@ -132,8 +125,7 @@ def test_basic_auth_works(app, test_client):
|
|||
assert req.status == 401
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_with_trusted_ip(test_client):
|
||||
async def test_access_with_trusted_ip(test_client):
|
||||
"""Test access with an untrusted ip address."""
|
||||
app = web.Application()
|
||||
app.router.add_get('/', mock_handler)
|
||||
|
@ -141,16 +133,16 @@ def test_access_with_trusted_ip(test_client):
|
|||
setup_auth(app, TRUSTED_NETWORKS, 'some-pass')
|
||||
|
||||
set_mock_ip = mock_real_ip(app)
|
||||
client = yield from test_client(app)
|
||||
client = await test_client(app)
|
||||
|
||||
for remote_addr in UNTRUSTED_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
resp = yield from client.get('/')
|
||||
resp = await client.get('/')
|
||||
assert resp.status == 401, \
|
||||
"{} shouldn't be trusted".format(remote_addr)
|
||||
|
||||
for remote_addr in TRUSTED_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
resp = yield from client.get('/')
|
||||
resp = await client.get('/')
|
||||
assert resp.status == 200, \
|
||||
"{} should be trusted".format(remote_addr)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
"""The tests for the Home Assistant HTTP component."""
|
||||
# pylint: disable=protected-access
|
||||
import asyncio
|
||||
from unittest.mock import patch, mock_open
|
||||
|
||||
from aiohttp import web
|
||||
|
@ -16,8 +15,7 @@ from . import mock_real_ip
|
|||
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_from_banned_ip(hass, test_client):
|
||||
async def test_access_from_banned_ip(hass, test_client):
|
||||
"""Test accessing to server from banned IP. Both trusted and not."""
|
||||
app = web.Application()
|
||||
setup_bans(hass, app, 5)
|
||||
|
@ -26,19 +24,18 @@ def test_access_from_banned_ip(hass, test_client):
|
|||
with patch('homeassistant.components.http.ban.load_ip_bans_config',
|
||||
return_value=[IpBan(banned_ip) for banned_ip
|
||||
in BANNED_IPS]):
|
||||
client = yield from test_client(app)
|
||||
client = await test_client(app)
|
||||
|
||||
for remote_addr in BANNED_IPS:
|
||||
set_real_ip(remote_addr)
|
||||
resp = yield from client.get('/')
|
||||
resp = await client.get('/')
|
||||
assert resp.status == 403
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_ban_middleware_not_loaded_by_config(hass):
|
||||
async def test_ban_middleware_not_loaded_by_config(hass):
|
||||
"""Test accessing to server from banned IP when feature is off."""
|
||||
with patch('homeassistant.components.http.setup_bans') as mock_setup:
|
||||
yield from async_setup_component(hass, 'http', {
|
||||
await async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
http.CONF_IP_BAN_ENABLED: False,
|
||||
}
|
||||
|
@ -47,25 +44,22 @@ def test_ban_middleware_not_loaded_by_config(hass):
|
|||
assert len(mock_setup.mock_calls) == 0
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_ban_middleware_loaded_by_default(hass):
|
||||
async def test_ban_middleware_loaded_by_default(hass):
|
||||
"""Test accessing to server from banned IP when feature is off."""
|
||||
with patch('homeassistant.components.http.setup_bans') as mock_setup:
|
||||
yield from async_setup_component(hass, 'http', {
|
||||
await async_setup_component(hass, 'http', {
|
||||
'http': {}
|
||||
})
|
||||
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_ip_bans_file_creation(hass, test_client):
|
||||
async def test_ip_bans_file_creation(hass, test_client):
|
||||
"""Testing if banned IP file created."""
|
||||
app = web.Application()
|
||||
app['hass'] = hass
|
||||
|
||||
@asyncio.coroutine
|
||||
def unauth_handler(request):
|
||||
async def unauth_handler(request):
|
||||
"""Return a mock web response."""
|
||||
raise HTTPUnauthorized
|
||||
|
||||
|
@ -76,21 +70,21 @@ def test_ip_bans_file_creation(hass, test_client):
|
|||
with patch('homeassistant.components.http.ban.load_ip_bans_config',
|
||||
return_value=[IpBan(banned_ip) for banned_ip
|
||||
in BANNED_IPS]):
|
||||
client = yield from test_client(app)
|
||||
client = await test_client(app)
|
||||
|
||||
m = mock_open()
|
||||
|
||||
with patch('homeassistant.components.http.ban.open', m, create=True):
|
||||
resp = yield from client.get('/')
|
||||
resp = await client.get('/')
|
||||
assert resp.status == 401
|
||||
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS)
|
||||
assert m.call_count == 0
|
||||
|
||||
resp = yield from client.get('/')
|
||||
resp = await client.get('/')
|
||||
assert resp.status == 401
|
||||
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
|
||||
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
|
||||
|
||||
resp = yield from client.get('/')
|
||||
resp = await client.get('/')
|
||||
assert resp.status == 403
|
||||
assert m.call_count == 1
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
"""Test cors for the HTTP component."""
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
from aiohttp import web
|
||||
|
@ -20,22 +19,20 @@ from homeassistant.components.http.cors import setup_cors
|
|||
TRUSTED_ORIGIN = 'https://home-assistant.io'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_cors_middleware_not_loaded_by_default(hass):
|
||||
async def test_cors_middleware_not_loaded_by_default(hass):
|
||||
"""Test accessing to server from banned IP when feature is off."""
|
||||
with patch('homeassistant.components.http.setup_cors') as mock_setup:
|
||||
yield from async_setup_component(hass, 'http', {
|
||||
await async_setup_component(hass, 'http', {
|
||||
'http': {}
|
||||
})
|
||||
|
||||
assert len(mock_setup.mock_calls) == 0
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_cors_middleware_loaded_from_config(hass):
|
||||
async def test_cors_middleware_loaded_from_config(hass):
|
||||
"""Test accessing to server from banned IP when feature is off."""
|
||||
with patch('homeassistant.components.http.setup_cors') as mock_setup:
|
||||
yield from async_setup_component(hass, 'http', {
|
||||
await async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
'cors_allowed_origins': ['http://home-assistant.io']
|
||||
}
|
||||
|
@ -44,8 +41,7 @@ def test_cors_middleware_loaded_from_config(hass):
|
|||
assert len(mock_setup.mock_calls) == 1
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def mock_handler(request):
|
||||
async def mock_handler(request):
|
||||
"""Return if request was authenticated."""
|
||||
return web.Response(status=200)
|
||||
|
||||
|
@ -59,10 +55,9 @@ def client(loop, test_client):
|
|||
return loop.run_until_complete(test_client(app))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_cors_requests(client):
|
||||
async def test_cors_requests(client):
|
||||
"""Test cross origin requests."""
|
||||
req = yield from client.get('/', headers={
|
||||
req = await client.get('/', headers={
|
||||
ORIGIN: TRUSTED_ORIGIN
|
||||
})
|
||||
assert req.status == 200
|
||||
|
@ -70,7 +65,7 @@ def test_cors_requests(client):
|
|||
TRUSTED_ORIGIN
|
||||
|
||||
# With password in URL
|
||||
req = yield from client.get('/', params={
|
||||
req = await client.get('/', params={
|
||||
'api_password': 'some-pass'
|
||||
}, headers={
|
||||
ORIGIN: TRUSTED_ORIGIN
|
||||
|
@ -80,7 +75,7 @@ def test_cors_requests(client):
|
|||
TRUSTED_ORIGIN
|
||||
|
||||
# With password in headers
|
||||
req = yield from client.get('/', headers={
|
||||
req = await client.get('/', headers={
|
||||
HTTP_HEADER_HA_AUTH: 'some-pass',
|
||||
ORIGIN: TRUSTED_ORIGIN
|
||||
})
|
||||
|
@ -89,10 +84,9 @@ def test_cors_requests(client):
|
|||
TRUSTED_ORIGIN
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_cors_preflight_allowed(client):
|
||||
async def test_cors_preflight_allowed(client):
|
||||
"""Test cross origin resource sharing preflight (OPTIONS) request."""
|
||||
req = yield from client.options('/', headers={
|
||||
req = await client.options('/', headers={
|
||||
ORIGIN: TRUSTED_ORIGIN,
|
||||
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
|
||||
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
"""Test data validator decorator."""
|
||||
import asyncio
|
||||
from unittest.mock import Mock
|
||||
|
||||
from aiohttp import web
|
||||
|
@ -9,8 +8,7 @@ from homeassistant.components.http import HomeAssistantView
|
|||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def get_client(test_client, validator):
|
||||
async def get_client(test_client, validator):
|
||||
"""Generate a client that hits a view decorated with validator."""
|
||||
app = web.Application()
|
||||
app['hass'] = Mock(is_running=True)
|
||||
|
@ -20,58 +18,55 @@ def get_client(test_client, validator):
|
|||
name = 'test'
|
||||
requires_auth = False
|
||||
|
||||
@asyncio.coroutine
|
||||
@validator
|
||||
def post(self, request, data):
|
||||
async def post(self, request, data):
|
||||
"""Test method."""
|
||||
return b''
|
||||
|
||||
TestView().register(app.router)
|
||||
client = yield from test_client(app)
|
||||
client = await test_client(app)
|
||||
return client
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_validator(test_client):
|
||||
async def test_validator(test_client):
|
||||
"""Test the validator."""
|
||||
client = yield from get_client(
|
||||
client = await get_client(
|
||||
test_client, RequestDataValidator(vol.Schema({
|
||||
vol.Required('test'): str
|
||||
})))
|
||||
|
||||
resp = yield from client.post('/', json={
|
||||
resp = await client.post('/', json={
|
||||
'test': 'bla'
|
||||
})
|
||||
assert resp.status == 200
|
||||
|
||||
resp = yield from client.post('/', json={
|
||||
resp = await client.post('/', json={
|
||||
'test': 100
|
||||
})
|
||||
assert resp.status == 400
|
||||
|
||||
resp = yield from client.post('/')
|
||||
resp = await client.post('/')
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_validator_allow_empty(test_client):
|
||||
async def test_validator_allow_empty(test_client):
|
||||
"""Test the validator with empty data."""
|
||||
client = yield from get_client(
|
||||
client = await get_client(
|
||||
test_client, RequestDataValidator(vol.Schema({
|
||||
# Although we allow empty, our schema should still be able
|
||||
# to validate an empty dict.
|
||||
vol.Optional('test'): str
|
||||
}), allow_empty=True))
|
||||
|
||||
resp = yield from client.post('/', json={
|
||||
resp = await client.post('/', json={
|
||||
'test': 'bla'
|
||||
})
|
||||
assert resp.status == 200
|
||||
|
||||
resp = yield from client.post('/', json={
|
||||
resp = await client.post('/', json={
|
||||
'test': 100
|
||||
})
|
||||
assert resp.status == 400
|
||||
|
||||
resp = yield from client.post('/')
|
||||
resp = await client.post('/')
|
||||
assert resp.status == 200
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
"""The tests for the Home Assistant HTTP component."""
|
||||
import asyncio
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
import homeassistant.components.http as http
|
||||
|
@ -12,16 +10,14 @@ class TestView(http.HomeAssistantView):
|
|||
name = 'test'
|
||||
url = '/hello'
|
||||
|
||||
@asyncio.coroutine
|
||||
def get(self, request):
|
||||
async def get(self, request):
|
||||
"""Return a get request."""
|
||||
return 'hello'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_view_while_running(hass, test_client, unused_port):
|
||||
async def test_registering_view_while_running(hass, test_client, unused_port):
|
||||
"""Test that we can register a view while the server is running."""
|
||||
yield from async_setup_component(
|
||||
await async_setup_component(
|
||||
hass, http.DOMAIN, {
|
||||
http.DOMAIN: {
|
||||
http.CONF_SERVER_PORT: unused_port(),
|
||||
|
@ -29,15 +25,14 @@ def test_registering_view_while_running(hass, test_client, unused_port):
|
|||
}
|
||||
)
|
||||
|
||||
yield from hass.async_start()
|
||||
await hass.async_start()
|
||||
# This raises a RuntimeError if app is frozen
|
||||
hass.http.register_view(TestView)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_api_base_url_with_domain(hass):
|
||||
async def test_api_base_url_with_domain(hass):
|
||||
"""Test setting API URL."""
|
||||
result = yield from async_setup_component(hass, 'http', {
|
||||
result = await async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
'base_url': 'example.com'
|
||||
}
|
||||
|
@ -46,10 +41,9 @@ def test_api_base_url_with_domain(hass):
|
|||
assert hass.config.api.base_url == 'http://example.com'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_api_base_url_with_ip(hass):
|
||||
async def test_api_base_url_with_ip(hass):
|
||||
"""Test setting api url."""
|
||||
result = yield from async_setup_component(hass, 'http', {
|
||||
result = await async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
'server_host': '1.1.1.1'
|
||||
}
|
||||
|
@ -58,10 +52,9 @@ def test_api_base_url_with_ip(hass):
|
|||
assert hass.config.api.base_url == 'http://1.1.1.1:8123'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_api_base_url_with_ip_port(hass):
|
||||
async def test_api_base_url_with_ip_port(hass):
|
||||
"""Test setting api url."""
|
||||
result = yield from async_setup_component(hass, 'http', {
|
||||
result = await async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
'base_url': '1.1.1.1:8124'
|
||||
}
|
||||
|
@ -70,10 +63,9 @@ def test_api_base_url_with_ip_port(hass):
|
|||
assert hass.config.api.base_url == 'http://1.1.1.1:8124'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_api_no_base_url(hass):
|
||||
async def test_api_no_base_url(hass):
|
||||
"""Test setting api url."""
|
||||
result = yield from async_setup_component(hass, 'http', {
|
||||
result = await async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
}
|
||||
})
|
||||
|
@ -81,10 +73,9 @@ def test_api_no_base_url(hass):
|
|||
assert hass.config.api.base_url == 'http://127.0.0.1:8123'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_not_log_password(hass, unused_port, test_client, caplog):
|
||||
async def test_not_log_password(hass, unused_port, test_client, caplog):
|
||||
"""Test access with password doesn't get logged."""
|
||||
result = yield from async_setup_component(hass, 'api', {
|
||||
result = await async_setup_component(hass, 'api', {
|
||||
'http': {
|
||||
http.CONF_SERVER_PORT: unused_port(),
|
||||
http.CONF_API_PASSWORD: 'some-pass'
|
||||
|
@ -92,9 +83,9 @@ def test_not_log_password(hass, unused_port, test_client, caplog):
|
|||
})
|
||||
assert result
|
||||
|
||||
client = yield from test_client(hass.http.app)
|
||||
client = await test_client(hass.http.app)
|
||||
|
||||
resp = yield from client.get('/api/', params={
|
||||
resp = await client.get('/api/', params={
|
||||
'api_password': 'some-pass'
|
||||
})
|
||||
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
"""Test real IP middleware."""
|
||||
import asyncio
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.hdrs import X_FORWARDED_FOR
|
||||
|
||||
|
@ -8,41 +6,38 @@ from homeassistant.components.http.real_ip import setup_real_ip
|
|||
from homeassistant.components.http.const import KEY_REAL_IP
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def mock_handler(request):
|
||||
async def mock_handler(request):
|
||||
"""Handler that returns the real IP as text."""
|
||||
return web.Response(text=str(request[KEY_REAL_IP]))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_ignore_x_forwarded_for(test_client):
|
||||
async def test_ignore_x_forwarded_for(test_client):
|
||||
"""Test that we get the IP from the transport."""
|
||||
app = web.Application()
|
||||
app.router.add_get('/', mock_handler)
|
||||
setup_real_ip(app, False)
|
||||
|
||||
mock_api_client = yield from test_client(app)
|
||||
mock_api_client = await test_client(app)
|
||||
|
||||
resp = yield from mock_api_client.get('/', headers={
|
||||
resp = await mock_api_client.get('/', headers={
|
||||
X_FORWARDED_FOR: '255.255.255.255'
|
||||
})
|
||||
assert resp.status == 200
|
||||
text = yield from resp.text()
|
||||
text = await resp.text()
|
||||
assert text != '255.255.255.255'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_use_x_forwarded_for(test_client):
|
||||
async def test_use_x_forwarded_for(test_client):
|
||||
"""Test that we get the IP from the transport."""
|
||||
app = web.Application()
|
||||
app.router.add_get('/', mock_handler)
|
||||
setup_real_ip(app, True)
|
||||
|
||||
mock_api_client = yield from test_client(app)
|
||||
mock_api_client = await test_client(app)
|
||||
|
||||
resp = yield from mock_api_client.get('/', headers={
|
||||
resp = await mock_api_client.get('/', headers={
|
||||
X_FORWARDED_FOR: '255.255.255.255'
|
||||
})
|
||||
assert resp.status == 200
|
||||
text = yield from resp.text()
|
||||
text = await resp.text()
|
||||
assert text == '255.255.255.255'
|
||||
|
|
Loading…
Reference in New Issue