Backend tweaks to make authorization work (#14339)
* Backend tweaks to make authorization work * Lint * Add test * Validate redirect uris * Fix tests * Fix tests * Lintpull/14377/head
parent
0f3ec94fba
commit
5ec7fc7ddb
homeassistant
components
helpers
|
@ -210,6 +210,7 @@ class Client:
|
|||
name = attr.ib(type=str)
|
||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||
secret = attr.ib(type=str, default=attr.Factory(generate_secret))
|
||||
redirect_uris = attr.ib(type=list, default=attr.Factory(list))
|
||||
|
||||
|
||||
async def load_auth_provider_module(hass, provider):
|
||||
|
@ -340,9 +341,11 @@ class AuthManager:
|
|||
"""Get an access token."""
|
||||
return self.access_tokens.get(token)
|
||||
|
||||
async def async_create_client(self, name):
|
||||
async def async_create_client(self, name, *, redirect_uris=None,
|
||||
no_secret=False):
|
||||
"""Create a new client."""
|
||||
return await self._store.async_create_client(name)
|
||||
return await self._store.async_create_client(
|
||||
name, redirect_uris, no_secret)
|
||||
|
||||
async def async_get_client(self, client_id):
|
||||
"""Get a client."""
|
||||
|
@ -477,12 +480,20 @@ class AuthStore:
|
|||
|
||||
return None
|
||||
|
||||
async def async_create_client(self, name):
|
||||
async def async_create_client(self, name, redirect_uris, no_secret):
|
||||
"""Create a new client."""
|
||||
if self.clients is None:
|
||||
await self.async_load()
|
||||
|
||||
client = Client(name)
|
||||
kwargs = {
|
||||
'name': name,
|
||||
'redirect_uris': redirect_uris
|
||||
}
|
||||
|
||||
if no_secret:
|
||||
kwargs['secret'] = None
|
||||
|
||||
client = Client(**kwargs)
|
||||
self.clients[client.id] = client
|
||||
await self.async_save()
|
||||
return client
|
||||
|
|
|
@ -356,7 +356,8 @@ class APIErrorLog(HomeAssistantView):
|
|||
|
||||
async def get(self, request):
|
||||
"""Retrieve API error log."""
|
||||
return await self.file(request, request.app['hass'].data[DATA_LOGGING])
|
||||
return web.FileResponse(
|
||||
request.app['hass'].data[DATA_LOGGING])
|
||||
|
||||
|
||||
async def async_services_json(hass):
|
||||
|
|
|
@ -144,7 +144,7 @@ class AuthProvidersView(HomeAssistantView):
|
|||
requires_auth = False
|
||||
|
||||
@verify_client
|
||||
async def get(self, request, client_id):
|
||||
async def get(self, request, client):
|
||||
"""Get available auth providers."""
|
||||
return self.json([{
|
||||
'name': provider.name,
|
||||
|
@ -166,8 +166,15 @@ class LoginFlowIndexView(FlowManagerIndexView):
|
|||
|
||||
# pylint: disable=arguments-differ
|
||||
@verify_client
|
||||
async def post(self, request, client_id):
|
||||
@RequestDataValidator(vol.Schema({
|
||||
vol.Required('handler'): vol.Any(str, list),
|
||||
vol.Required('redirect_uri'): str,
|
||||
}))
|
||||
async def post(self, request, client, data):
|
||||
"""Create a new login flow."""
|
||||
if data['redirect_uri'] not in client.redirect_uris:
|
||||
return self.json_message('invalid redirect uri', )
|
||||
|
||||
# pylint: disable=no-value-for-parameter
|
||||
return await super().post(request)
|
||||
|
||||
|
@ -192,7 +199,7 @@ class LoginFlowResourceView(FlowManagerResourceView):
|
|||
# pylint: disable=arguments-differ
|
||||
@verify_client
|
||||
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
|
||||
async def post(self, request, client_id, flow_id, data):
|
||||
async def post(self, request, client, flow_id, data):
|
||||
"""Handle progressing a login flow request."""
|
||||
try:
|
||||
result = await self._flow_mgr.async_configure(flow_id, data)
|
||||
|
@ -205,7 +212,7 @@ class LoginFlowResourceView(FlowManagerResourceView):
|
|||
return self.json(self._prepare_result_json(result))
|
||||
|
||||
result.pop('data')
|
||||
result['result'] = self._store_credentials(client_id, result['result'])
|
||||
result['result'] = self._store_credentials(client.id, result['result'])
|
||||
|
||||
return self.json(result)
|
||||
|
||||
|
@ -222,7 +229,7 @@ class GrantTokenView(HomeAssistantView):
|
|||
self._retrieve_credentials = retrieve_credentials
|
||||
|
||||
@verify_client
|
||||
async def post(self, request, client_id):
|
||||
async def post(self, request, client):
|
||||
"""Grant a token."""
|
||||
hass = request.app['hass']
|
||||
data = await request.post()
|
||||
|
@ -230,11 +237,11 @@ class GrantTokenView(HomeAssistantView):
|
|||
|
||||
if grant_type == 'authorization_code':
|
||||
return await self._async_handle_auth_code(
|
||||
hass, client_id, data)
|
||||
hass, client.id, data)
|
||||
|
||||
elif grant_type == 'refresh_token':
|
||||
return await self._async_handle_refresh_token(
|
||||
hass, client_id, data)
|
||||
hass, client.id, data)
|
||||
|
||||
return self.json({
|
||||
'error': 'unsupported_grant_type',
|
||||
|
|
|
@ -11,15 +11,15 @@ def verify_client(method):
|
|||
@wraps(method)
|
||||
async def wrapper(view, request, *args, **kwargs):
|
||||
"""Verify client id/secret before doing request."""
|
||||
client_id = await _verify_client(request)
|
||||
client = await _verify_client(request)
|
||||
|
||||
if client_id is None:
|
||||
if client is None:
|
||||
return view.json({
|
||||
'error': 'invalid_client',
|
||||
}, status_code=401)
|
||||
|
||||
return await method(
|
||||
view, request, *args, client_id=client_id, **kwargs)
|
||||
view, request, *args, **kwargs, client=client)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
@ -46,18 +46,34 @@ async def _verify_client(request):
|
|||
client_id, client_secret = decoded.split(':', 1)
|
||||
except ValueError:
|
||||
# If no ':' in decoded
|
||||
return None
|
||||
client_id, client_secret = decoded, None
|
||||
|
||||
client = await request.app['hass'].auth.async_get_client(client_id)
|
||||
return await async_secure_get_client(
|
||||
request.app['hass'], client_id, client_secret)
|
||||
|
||||
|
||||
async def async_secure_get_client(hass, client_id, client_secret):
|
||||
"""Get a client id/secret in consistent time."""
|
||||
client = await hass.auth.async_get_client(client_id)
|
||||
|
||||
if client is None:
|
||||
# Still do a compare so we run same time as if a client was found.
|
||||
hmac.compare_digest(client_secret.encode('utf-8'),
|
||||
client_secret.encode('utf-8'))
|
||||
if client_secret is not None:
|
||||
# Still do a compare so we run same time as if a client was found.
|
||||
hmac.compare_digest(client_secret.encode('utf-8'),
|
||||
client_secret.encode('utf-8'))
|
||||
return None
|
||||
|
||||
if hmac.compare_digest(client_secret.encode('utf-8'),
|
||||
client.secret.encode('utf-8')):
|
||||
return client_id
|
||||
if client.secret is None:
|
||||
return client
|
||||
|
||||
elif client_secret is None:
|
||||
# Still do a compare so we run same time as if a secret was passed.
|
||||
hmac.compare_digest(client.secret.encode('utf-8'),
|
||||
client.secret.encode('utf-8'))
|
||||
return None
|
||||
|
||||
elif hmac.compare_digest(client_secret.encode('utf-8'),
|
||||
client.secret.encode('utf-8')):
|
||||
return client
|
||||
|
||||
return None
|
||||
|
|
|
@ -296,6 +296,15 @@ def add_manifest_json_key(key, val):
|
|||
@asyncio.coroutine
|
||||
def async_setup(hass, config):
|
||||
"""Set up the serving of the frontend."""
|
||||
if list(hass.auth.async_auth_providers):
|
||||
client = yield from hass.auth.async_create_client(
|
||||
'Home Assistant Frontend',
|
||||
redirect_uris=['/'],
|
||||
no_secret=True,
|
||||
)
|
||||
else:
|
||||
client = None
|
||||
|
||||
hass.components.websocket_api.async_register_command(
|
||||
WS_TYPE_GET_PANELS, websocket_handle_get_panels, SCHEMA_GET_PANELS)
|
||||
hass.http.register_view(ManifestJSONView)
|
||||
|
@ -353,7 +362,7 @@ def async_setup(hass, config):
|
|||
if os.path.isdir(local):
|
||||
hass.http.register_static_path("/local", local, not is_dev)
|
||||
|
||||
index_view = IndexView(repo_path, js_version)
|
||||
index_view = IndexView(repo_path, js_version, client)
|
||||
hass.http.register_view(index_view)
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -451,10 +460,11 @@ class IndexView(HomeAssistantView):
|
|||
requires_auth = False
|
||||
extra_urls = ['/states', '/states/{extra}']
|
||||
|
||||
def __init__(self, repo_path, js_option):
|
||||
def __init__(self, repo_path, js_option, client):
|
||||
"""Initialize the frontend view."""
|
||||
self.repo_path = repo_path
|
||||
self.js_option = js_option
|
||||
self.client = client
|
||||
self._template_cache = {}
|
||||
|
||||
def get_template(self, latest):
|
||||
|
@ -508,7 +518,7 @@ class IndexView(HomeAssistantView):
|
|||
|
||||
extra_key = DATA_EXTRA_HTML_URL if latest else DATA_EXTRA_HTML_URL_ES5
|
||||
|
||||
resp = template.render(
|
||||
template_params = dict(
|
||||
no_auth=no_auth,
|
||||
panel_url=panel_url,
|
||||
panels=hass.data[DATA_PANELS],
|
||||
|
@ -516,7 +526,11 @@ class IndexView(HomeAssistantView):
|
|||
extra_urls=hass.data[extra_key],
|
||||
)
|
||||
|
||||
return web.Response(text=resp, content_type='text/html')
|
||||
if self.client is not None:
|
||||
template_params['client_id'] = self.client.id
|
||||
|
||||
return web.Response(text=template.render(**template_params),
|
||||
content_type='text/html')
|
||||
|
||||
|
||||
class ManifestJSONView(HomeAssistantView):
|
||||
|
|
|
@ -81,7 +81,12 @@ async def async_validate_auth_header(api_password, request):
|
|||
if hdrs.AUTHORIZATION not in request.headers:
|
||||
return False
|
||||
|
||||
auth_type, auth_val = request.headers.get(hdrs.AUTHORIZATION).split(' ', 1)
|
||||
try:
|
||||
auth_type, auth_val = \
|
||||
request.headers.get(hdrs.AUTHORIZATION).split(' ', 1)
|
||||
except ValueError:
|
||||
# If no space in authorization header
|
||||
return False
|
||||
|
||||
if auth_type == 'Basic':
|
||||
decoded = base64.b64decode(auth_val).decode('utf-8')
|
||||
|
|
|
@ -51,12 +51,6 @@ class HomeAssistantView(object):
|
|||
data['code'] = message_code
|
||||
return self.json(data, status_code, headers=headers)
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
async def file(self, request, fil):
|
||||
"""Return a file."""
|
||||
assert isinstance(fil, str), 'only string paths allowed'
|
||||
return web.FileResponse(fil)
|
||||
|
||||
def register(self, router):
|
||||
"""Register the view with a router."""
|
||||
assert self.url is not None, 'No url set for view'
|
||||
|
|
|
@ -60,7 +60,8 @@ JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
|
|||
|
||||
AUTH_MESSAGE_SCHEMA = vol.Schema({
|
||||
vol.Required('type'): TYPE_AUTH,
|
||||
vol.Required('api_password'): str,
|
||||
vol.Exclusive('api_password', 'auth'): str,
|
||||
vol.Exclusive('access_token', 'auth'): str,
|
||||
})
|
||||
|
||||
# Minimal requirements of a message
|
||||
|
@ -318,15 +319,18 @@ class ActiveConnection:
|
|||
msg = await wsock.receive_json()
|
||||
msg = AUTH_MESSAGE_SCHEMA(msg)
|
||||
|
||||
if validate_password(request, msg['api_password']):
|
||||
authenticated = True
|
||||
if 'api_password' in msg:
|
||||
authenticated = validate_password(
|
||||
request, msg['api_password'])
|
||||
|
||||
else:
|
||||
self.debug("Invalid password")
|
||||
await self.wsock.send_json(
|
||||
auth_invalid_message('Invalid password'))
|
||||
elif 'access_token' in msg:
|
||||
authenticated = \
|
||||
msg['access_token'] in self.hass.auth.access_tokens
|
||||
|
||||
if not authenticated:
|
||||
self.debug("Invalid password")
|
||||
await self.wsock.send_json(
|
||||
auth_invalid_message('Invalid password'))
|
||||
await process_wrong_login(request)
|
||||
return wsock
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
|
|||
|
||||
@RequestDataValidator(vol.Schema({
|
||||
vol.Required('handler'): vol.Any(str, list),
|
||||
}))
|
||||
}, extra=vol.ALLOW_EXTRA))
|
||||
async def post(self, request, data):
|
||||
"""Handle a POST request."""
|
||||
if isinstance(data['handler'], list):
|
||||
|
|
|
@ -19,6 +19,7 @@ BASE_CONFIG = [{
|
|||
CLIENT_ID = 'test-id'
|
||||
CLIENT_SECRET = 'test-secret'
|
||||
CLIENT_AUTH = BasicAuth(CLIENT_ID, CLIENT_SECRET)
|
||||
CLIENT_REDIRECT_URI = 'http://example.com/callback'
|
||||
|
||||
|
||||
async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
|
||||
|
@ -31,7 +32,8 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
|
|||
'api_password': 'bla'
|
||||
}
|
||||
})
|
||||
client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET)
|
||||
client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET,
|
||||
redirect_uris=[CLIENT_REDIRECT_URI])
|
||||
hass.auth._store.clients[client.id] = client
|
||||
if setup_api:
|
||||
await async_setup_component(hass, 'api', {})
|
||||
|
|
|
@ -21,9 +21,9 @@ def mock_view(hass):
|
|||
name = 'bla'
|
||||
|
||||
@verify_client
|
||||
async def get(self, request, client_id):
|
||||
async def get(self, request, client):
|
||||
"""Handle GET request."""
|
||||
clients.append(client_id)
|
||||
clients.append(client)
|
||||
|
||||
hass.http.register_view(ClientView)
|
||||
return clients
|
||||
|
@ -36,7 +36,7 @@ async def test_verify_client(hass, aiohttp_client, mock_view):
|
|||
|
||||
resp = await http_client.get('/', auth=BasicAuth(client.id, client.secret))
|
||||
assert resp.status == 200
|
||||
assert mock_view == [client.id]
|
||||
assert mock_view[0] is client
|
||||
|
||||
|
||||
async def test_verify_client_no_auth_header(hass, aiohttp_client, mock_view):
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
"""Integration tests for the auth component."""
|
||||
from . import async_setup_auth, CLIENT_AUTH
|
||||
from . import async_setup_auth, CLIENT_AUTH, CLIENT_REDIRECT_URI
|
||||
|
||||
|
||||
async def test_login_new_user_and_refresh_token(hass, aiohttp_client):
|
||||
"""Test logging in with new user and refreshing tokens."""
|
||||
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
|
||||
resp = await client.post('/auth/login_flow', json={
|
||||
'handler': ['insecure_example', None]
|
||||
'handler': ['insecure_example', None],
|
||||
'redirect_uri': CLIENT_REDIRECT_URI,
|
||||
}, auth=CLIENT_AUTH)
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""Tests for the link user flow."""
|
||||
from . import async_setup_auth, CLIENT_AUTH, CLIENT_ID
|
||||
from . import async_setup_auth, CLIENT_AUTH, CLIENT_ID, CLIENT_REDIRECT_URI
|
||||
|
||||
|
||||
async def async_get_code(hass, aiohttp_client):
|
||||
|
@ -25,7 +25,8 @@ async def async_get_code(hass, aiohttp_client):
|
|||
client = await async_setup_auth(hass, aiohttp_client, config)
|
||||
|
||||
resp = await client.post('/auth/login_flow', json={
|
||||
'handler': ['insecure_example', None]
|
||||
'handler': ['insecure_example', None],
|
||||
'redirect_uri': CLIENT_REDIRECT_URI,
|
||||
}, auth=CLIENT_AUTH)
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
|
@ -56,7 +57,8 @@ async def async_get_code(hass, aiohttp_client):
|
|||
|
||||
# Now authenticate with the 2nd flow
|
||||
resp = await client.post('/auth/login_flow', json={
|
||||
'handler': ['insecure_example', '2nd auth']
|
||||
'handler': ['insecure_example', '2nd auth'],
|
||||
'redirect_uri': CLIENT_REDIRECT_URI,
|
||||
}, auth=CLIENT_AUTH)
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""Tests for the login flow."""
|
||||
from aiohttp.helpers import BasicAuth
|
||||
|
||||
from . import async_setup_auth, CLIENT_AUTH
|
||||
from . import async_setup_auth, CLIENT_AUTH, CLIENT_REDIRECT_URI
|
||||
|
||||
|
||||
async def test_fetch_auth_providers(hass, aiohttp_client):
|
||||
|
@ -34,7 +34,8 @@ async def test_invalid_username_password(hass, aiohttp_client):
|
|||
"""Test we cannot get flows in progress."""
|
||||
client = await async_setup_auth(hass, aiohttp_client)
|
||||
resp = await client.post('/auth/login_flow', json={
|
||||
'handler': ['insecure_example', None]
|
||||
'handler': ['insecure_example', None],
|
||||
'redirect_uri': CLIENT_REDIRECT_URI
|
||||
}, auth=CLIENT_AUTH)
|
||||
assert resp.status == 200
|
||||
step = await resp.json()
|
||||
|
|
|
@ -3,6 +3,8 @@ import pytest
|
|||
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockUser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hass_ws_client(aiohttp_client):
|
||||
|
@ -20,3 +22,17 @@ def hass_ws_client(aiohttp_client):
|
|||
return websocket
|
||||
|
||||
return create_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hass_access_token(hass):
|
||||
"""Return an access token to access Home Assistant."""
|
||||
user = MockUser().add_to_hass(hass)
|
||||
client = hass.loop.run_until_complete(hass.auth.async_create_client(
|
||||
'Access Token Fixture',
|
||||
redirect_uris=['/'],
|
||||
no_secret=True,
|
||||
))
|
||||
refresh_token = hass.loop.run_until_complete(
|
||||
hass.auth.async_create_refresh_token(user, client.id))
|
||||
yield hass.auth.async_create_access_token(refresh_token)
|
||||
|
|
|
@ -12,8 +12,6 @@ from homeassistant.bootstrap import DATA_LOGGING
|
|||
import homeassistant.core as ha
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import mock_coro
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_client(hass, aiohttp_client):
|
||||
|
@ -420,14 +418,14 @@ async def test_api_error_log(hass, aiohttp_client):
|
|||
assert resp.status == 401
|
||||
|
||||
with patch(
|
||||
'homeassistant.components.http.view.HomeAssistantView.file',
|
||||
return_value=mock_coro(web.Response(status=200, text='Hello'))
|
||||
'aiohttp.web.FileResponse',
|
||||
return_value=web.Response(status=200, text='Hello')
|
||||
) as mock_file:
|
||||
resp = await client.get(const.URL_API_ERROR_LOG, headers={
|
||||
'x-ha-access': 'yolo'
|
||||
})
|
||||
|
||||
assert len(mock_file.mock_calls) == 1
|
||||
assert mock_file.mock_calls[0][1][1] == hass.data[DATA_LOGGING]
|
||||
assert mock_file.mock_calls[0][1][0] == hass.data[DATA_LOGGING]
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == 'Hello'
|
||||
|
|
|
@ -313,3 +313,49 @@ def test_unknown_command(websocket_client):
|
|||
|
||||
msg = yield from websocket_client.receive()
|
||||
assert msg.type == WSMsgType.close
|
||||
|
||||
|
||||
async def test_auth_with_token(hass, aiohttp_client, hass_access_token):
|
||||
"""Test authenticating with a token."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': hass_access_token.token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
|
||||
async def test_auth_with_invalid_token(hass, aiohttp_client):
|
||||
"""Test authenticating with a token."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': 'incorrect'
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
|
|
Loading…
Reference in New Issue