Use JWT for access tokens (#15972)
* Use JWT for access tokens * Update requirements * Improvementspull/15978/head
parent
ee5d49a033
commit
e776f88eec
|
@ -4,10 +4,12 @@ import logging
|
|||
from collections import OrderedDict
|
||||
from typing import List, Awaitable
|
||||
|
||||
import jwt
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import callback, HomeAssistant
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import models
|
||||
from . import auth_store
|
||||
from .providers import auth_provider_from_config
|
||||
|
||||
|
@ -54,7 +56,6 @@ class AuthManager:
|
|||
self.login_flow = data_entry_flow.FlowManager(
|
||||
hass, self._async_create_login_flow,
|
||||
self._async_finish_login_flow)
|
||||
self._access_tokens = OrderedDict()
|
||||
|
||||
@property
|
||||
def active(self):
|
||||
|
@ -181,35 +182,56 @@ class AuthManager:
|
|||
|
||||
return await self._store.async_create_refresh_token(user, client_id)
|
||||
|
||||
async def async_get_refresh_token(self, token):
|
||||
async def async_get_refresh_token(self, token_id):
|
||||
"""Get refresh token by id."""
|
||||
return await self._store.async_get_refresh_token(token_id)
|
||||
|
||||
async def async_get_refresh_token_by_token(self, token):
|
||||
"""Get refresh token by token."""
|
||||
return await self._store.async_get_refresh_token(token)
|
||||
return await self._store.async_get_refresh_token_by_token(token)
|
||||
|
||||
@callback
|
||||
def async_create_access_token(self, refresh_token):
|
||||
"""Create a new access token."""
|
||||
access_token = models.AccessToken(refresh_token=refresh_token)
|
||||
self._access_tokens[access_token.token] = access_token
|
||||
return access_token
|
||||
# pylint: disable=no-self-use
|
||||
return jwt.encode({
|
||||
'iss': refresh_token.id,
|
||||
'iat': dt_util.utcnow(),
|
||||
'exp': dt_util.utcnow() + refresh_token.access_token_expiration,
|
||||
}, refresh_token.jwt_key, algorithm='HS256').decode()
|
||||
|
||||
@callback
|
||||
def async_get_access_token(self, token):
|
||||
"""Get an access token."""
|
||||
tkn = self._access_tokens.get(token)
|
||||
|
||||
if tkn is None:
|
||||
_LOGGER.debug('Attempt to get non-existing access token')
|
||||
async def async_validate_access_token(self, token):
|
||||
"""Return if an access token is valid."""
|
||||
try:
|
||||
unverif_claims = jwt.decode(token, verify=False)
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
||||
|
||||
if tkn.expired or not tkn.refresh_token.user.is_active:
|
||||
if tkn.expired:
|
||||
_LOGGER.debug('Attempt to get expired access token')
|
||||
else:
|
||||
_LOGGER.debug('Attempt to get access token for inactive user')
|
||||
self._access_tokens.pop(token)
|
||||
refresh_token = await self.async_get_refresh_token(
|
||||
unverif_claims.get('iss'))
|
||||
|
||||
if refresh_token is None:
|
||||
jwt_key = ''
|
||||
issuer = ''
|
||||
else:
|
||||
jwt_key = refresh_token.jwt_key
|
||||
issuer = refresh_token.id
|
||||
|
||||
try:
|
||||
jwt.decode(
|
||||
token,
|
||||
jwt_key,
|
||||
leeway=10,
|
||||
issuer=issuer,
|
||||
algorithms=['HS256']
|
||||
)
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
||||
|
||||
return tkn
|
||||
if not refresh_token.user.is_active:
|
||||
return None
|
||||
|
||||
return refresh_token
|
||||
|
||||
async def _async_create_login_flow(self, handler, *, context, data):
|
||||
"""Create a login flow."""
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Storage for auth models."""
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
import hmac
|
||||
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
|
@ -110,22 +111,36 @@ class AuthStore:
|
|||
async def async_create_refresh_token(self, user, client_id=None):
|
||||
"""Create a new token for a user."""
|
||||
refresh_token = models.RefreshToken(user=user, client_id=client_id)
|
||||
user.refresh_tokens[refresh_token.token] = refresh_token
|
||||
user.refresh_tokens[refresh_token.id] = refresh_token
|
||||
await self.async_save()
|
||||
return refresh_token
|
||||
|
||||
async def async_get_refresh_token(self, token):
|
||||
"""Get refresh token by token."""
|
||||
async def async_get_refresh_token(self, token_id):
|
||||
"""Get refresh token by id."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
for user in self._users.values():
|
||||
refresh_token = user.refresh_tokens.get(token)
|
||||
refresh_token = user.refresh_tokens.get(token_id)
|
||||
if refresh_token is not None:
|
||||
return refresh_token
|
||||
|
||||
return None
|
||||
|
||||
async def async_get_refresh_token_by_token(self, token):
|
||||
"""Get refresh token by token."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
found = None
|
||||
|
||||
for user in self._users.values():
|
||||
for refresh_token in user.refresh_tokens.values():
|
||||
if hmac.compare_digest(refresh_token.token, token):
|
||||
found = refresh_token
|
||||
|
||||
return found
|
||||
|
||||
async def async_load(self):
|
||||
"""Load the users."""
|
||||
data = await self._store.async_load()
|
||||
|
@ -153,9 +168,11 @@ class AuthStore:
|
|||
data=cred_dict['data'],
|
||||
))
|
||||
|
||||
refresh_tokens = OrderedDict()
|
||||
|
||||
for rt_dict in data['refresh_tokens']:
|
||||
# Filter out the old keys that don't have jwt_key (pre-0.76)
|
||||
if 'jwt_key' not in rt_dict:
|
||||
continue
|
||||
|
||||
token = models.RefreshToken(
|
||||
id=rt_dict['id'],
|
||||
user=users[rt_dict['user_id']],
|
||||
|
@ -164,18 +181,9 @@ class AuthStore:
|
|||
access_token_expiration=timedelta(
|
||||
seconds=rt_dict['access_token_expiration']),
|
||||
token=rt_dict['token'],
|
||||
jwt_key=rt_dict['jwt_key']
|
||||
)
|
||||
refresh_tokens[token.id] = token
|
||||
users[rt_dict['user_id']].refresh_tokens[token.token] = token
|
||||
|
||||
for ac_dict in data['access_tokens']:
|
||||
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
|
||||
token = models.AccessToken(
|
||||
refresh_token=refresh_token,
|
||||
created_at=dt_util.parse_datetime(ac_dict['created_at']),
|
||||
token=ac_dict['token'],
|
||||
)
|
||||
refresh_token.access_tokens.append(token)
|
||||
users[rt_dict['user_id']].refresh_tokens[token.id] = token
|
||||
|
||||
self._users = users
|
||||
|
||||
|
@ -213,27 +221,15 @@ class AuthStore:
|
|||
'access_token_expiration':
|
||||
refresh_token.access_token_expiration.total_seconds(),
|
||||
'token': refresh_token.token,
|
||||
'jwt_key': refresh_token.jwt_key,
|
||||
}
|
||||
for user in self._users.values()
|
||||
for refresh_token in user.refresh_tokens.values()
|
||||
]
|
||||
|
||||
access_tokens = [
|
||||
{
|
||||
'id': user.id,
|
||||
'refresh_token_id': refresh_token.id,
|
||||
'created_at': access_token.created_at.isoformat(),
|
||||
'token': access_token.token,
|
||||
}
|
||||
for user in self._users.values()
|
||||
for refresh_token in user.refresh_tokens.values()
|
||||
for access_token in refresh_token.access_tokens
|
||||
]
|
||||
|
||||
data = {
|
||||
'users': users,
|
||||
'credentials': credentials,
|
||||
'access_tokens': access_tokens,
|
||||
'refresh_tokens': refresh_tokens,
|
||||
}
|
||||
|
||||
|
|
|
@ -39,26 +39,8 @@ class RefreshToken:
|
|||
default=ACCESS_TOKEN_EXPIRATION)
|
||||
token = attr.ib(type=str,
|
||||
default=attr.Factory(lambda: generate_secret(64)))
|
||||
access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class AccessToken:
|
||||
"""Access token to access the API.
|
||||
|
||||
These will only ever be stored in memory and not be persisted.
|
||||
"""
|
||||
|
||||
refresh_token = attr.ib(type=RefreshToken)
|
||||
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
||||
token = attr.ib(type=str,
|
||||
default=attr.Factory(generate_secret))
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
"""Return if this token has expired."""
|
||||
expires = self.created_at + self.refresh_token.access_token_expiration
|
||||
return dt_util.utcnow() > expires
|
||||
jwt_key = attr.ib(type=str,
|
||||
default=attr.Factory(lambda: generate_secret(64)))
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
|
|
|
@ -155,7 +155,7 @@ class GrantTokenView(HomeAssistantView):
|
|||
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||
|
||||
return self.json({
|
||||
'access_token': access_token.token,
|
||||
'access_token': access_token,
|
||||
'token_type': 'Bearer',
|
||||
'refresh_token': refresh_token.token,
|
||||
'expires_in':
|
||||
|
@ -178,7 +178,7 @@ class GrantTokenView(HomeAssistantView):
|
|||
'error': 'invalid_request',
|
||||
}, status_code=400)
|
||||
|
||||
refresh_token = await hass.auth.async_get_refresh_token(token)
|
||||
refresh_token = await hass.auth.async_get_refresh_token_by_token(token)
|
||||
|
||||
if refresh_token is None:
|
||||
return self.json({
|
||||
|
@ -193,7 +193,7 @@ class GrantTokenView(HomeAssistantView):
|
|||
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||
|
||||
return self.json({
|
||||
'access_token': access_token.token,
|
||||
'access_token': access_token,
|
||||
'token_type': 'Bearer',
|
||||
'expires_in':
|
||||
int(refresh_token.access_token_expiration.total_seconds()),
|
||||
|
|
|
@ -106,11 +106,11 @@ async def async_validate_auth_header(request, api_password=None):
|
|||
|
||||
if auth_type == 'Bearer':
|
||||
hass = request.app['hass']
|
||||
access_token = hass.auth.async_get_access_token(auth_val)
|
||||
if access_token is None:
|
||||
refresh_token = await hass.auth.async_validate_access_token(auth_val)
|
||||
if refresh_token is None:
|
||||
return False
|
||||
|
||||
request['hass_user'] = access_token.refresh_token.user
|
||||
request['hass_user'] = refresh_token.user
|
||||
return True
|
||||
|
||||
if auth_type == 'Basic' and api_password is not None:
|
||||
|
|
|
@ -355,11 +355,12 @@ class ActiveConnection:
|
|||
|
||||
if self.hass.auth.active and 'access_token' in msg:
|
||||
self.debug("Received access_token")
|
||||
token = self.hass.auth.async_get_access_token(
|
||||
msg['access_token'])
|
||||
authenticated = token is not None
|
||||
refresh_token = \
|
||||
await self.hass.auth.async_validate_access_token(
|
||||
msg['access_token'])
|
||||
authenticated = refresh_token is not None
|
||||
if authenticated:
|
||||
request['hass_user'] = token.refresh_token.user
|
||||
request['hass_user'] = refresh_token.user
|
||||
|
||||
elif ((not self.hass.auth.active or
|
||||
self.hass.auth.support_legacy) and
|
||||
|
|
|
@ -4,6 +4,7 @@ async_timeout==3.0.0
|
|||
attrs==18.1.0
|
||||
certifi>=2018.04.16
|
||||
jinja2>=2.10
|
||||
PyJWT==1.6.4
|
||||
pip>=8.0.3
|
||||
pytz>=2018.04
|
||||
pyyaml>=3.13,<4
|
||||
|
|
|
@ -5,6 +5,7 @@ async_timeout==3.0.0
|
|||
attrs==18.1.0
|
||||
certifi>=2018.04.16
|
||||
jinja2>=2.10
|
||||
PyJWT==1.6.4
|
||||
pip>=8.0.3
|
||||
pytz>=2018.04
|
||||
pyyaml>=3.13,<4
|
||||
|
|
1
setup.py
1
setup.py
|
@ -38,6 +38,7 @@ REQUIRES = [
|
|||
'attrs==18.1.0',
|
||||
'certifi>=2018.04.16',
|
||||
'jinja2>=2.10',
|
||||
'PyJWT==1.6.4',
|
||||
'pip>=8.0.3',
|
||||
'pytz>=2018.04',
|
||||
'pyyaml>=3.13,<4',
|
||||
|
|
|
@ -199,9 +199,7 @@ async def test_saving_loading(hass, hass_storage):
|
|||
})
|
||||
user = await manager.async_get_or_create_user(step['result'])
|
||||
await manager.async_activate_user(user)
|
||||
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
|
||||
manager.async_create_access_token(refresh_token)
|
||||
await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
|
||||
await flush_store(manager._store._store)
|
||||
|
||||
|
@ -211,30 +209,6 @@ async def test_saving_loading(hass, hass_storage):
|
|||
assert users[0] == user
|
||||
|
||||
|
||||
def test_access_token_expired():
|
||||
"""Test that the expired property on access tokens work."""
|
||||
refresh_token = auth_models.RefreshToken(
|
||||
user=None,
|
||||
client_id='bla'
|
||||
)
|
||||
|
||||
access_token = auth_models.AccessToken(
|
||||
refresh_token=refresh_token
|
||||
)
|
||||
|
||||
assert access_token.expired is False
|
||||
|
||||
with patch('homeassistant.util.dt.utcnow',
|
||||
return_value=dt_util.utcnow() +
|
||||
auth_const.ACCESS_TOKEN_EXPIRATION):
|
||||
assert access_token.expired is True
|
||||
|
||||
almost_exp = \
|
||||
dt_util.utcnow() + auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(1)
|
||||
with patch('homeassistant.util.dt.utcnow', return_value=almost_exp):
|
||||
assert access_token.expired is False
|
||||
|
||||
|
||||
async def test_cannot_retrieve_expired_access_token(hass):
|
||||
"""Test that we cannot retrieve expired access tokens."""
|
||||
manager = await auth.auth_manager_from_config(hass, [])
|
||||
|
@ -244,15 +218,20 @@ async def test_cannot_retrieve_expired_access_token(hass):
|
|||
assert refresh_token.client_id == CLIENT_ID
|
||||
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
assert manager.async_get_access_token(access_token.token) is access_token
|
||||
assert (
|
||||
await manager.async_validate_access_token(access_token)
|
||||
is refresh_token
|
||||
)
|
||||
|
||||
with patch('homeassistant.util.dt.utcnow',
|
||||
return_value=dt_util.utcnow() +
|
||||
auth_const.ACCESS_TOKEN_EXPIRATION):
|
||||
assert manager.async_get_access_token(access_token.token) is None
|
||||
return_value=dt_util.utcnow() -
|
||||
auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(seconds=11)):
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
# Even with unpatched time, it should have been removed from manager
|
||||
assert manager.async_get_access_token(access_token.token) is None
|
||||
assert (
|
||||
await manager.async_validate_access_token(access_token)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
async def test_generating_system_user(hass):
|
||||
|
|
|
@ -314,12 +314,18 @@ def mock_registry(hass, mock_entries=None):
|
|||
class MockUser(auth_models.User):
|
||||
"""Mock a user in Home Assistant."""
|
||||
|
||||
def __init__(self, id='mock-id', is_owner=False, is_active=True,
|
||||
def __init__(self, id=None, is_owner=False, is_active=True,
|
||||
name='Mock User', system_generated=False):
|
||||
"""Initialize mock user."""
|
||||
super().__init__(
|
||||
id=id, is_owner=is_owner, is_active=is_active, name=name,
|
||||
system_generated=system_generated)
|
||||
kwargs = {
|
||||
'is_owner': is_owner,
|
||||
'is_active': is_active,
|
||||
'name': name,
|
||||
'system_generated': system_generated
|
||||
}
|
||||
if id is not None:
|
||||
kwargs['id'] = id
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def add_to_hass(self, hass):
|
||||
"""Test helper to add entry to hass."""
|
||||
|
|
|
@ -44,7 +44,10 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
|||
assert resp.status == 200
|
||||
tokens = await resp.json()
|
||||
|
||||
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
||||
assert (
|
||||
await hass.auth.async_validate_access_token(tokens['access_token'])
|
||||
is not None
|
||||
)
|
||||
|
||||
# Use refresh token to get more tokens.
|
||||
resp = await client.post('/auth/token', data={
|
||||
|
@ -56,7 +59,10 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
|||
assert resp.status == 200
|
||||
tokens = await resp.json()
|
||||
assert 'refresh_token' not in tokens
|
||||
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
||||
assert (
|
||||
await hass.auth.async_validate_access_token(tokens['access_token'])
|
||||
is not None
|
||||
)
|
||||
|
||||
# Test using access token to hit API.
|
||||
resp = await client.get('/api/')
|
||||
|
@ -98,7 +104,9 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
|
|||
}
|
||||
})
|
||||
|
||||
user = hass_access_token.refresh_token.user
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
user = refresh_token.user
|
||||
credential = Credentials(auth_provider_type='homeassistant',
|
||||
auth_provider_id=None,
|
||||
data={}, id='test-id')
|
||||
|
@ -169,7 +177,10 @@ async def test_refresh_token_system_generated(hass, aiohttp_client):
|
|||
|
||||
assert resp.status == 200
|
||||
tokens = await resp.json()
|
||||
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
||||
assert (
|
||||
await hass.auth.async_validate_access_token(tokens['access_token'])
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
||||
|
@ -208,4 +219,7 @@ async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
|||
|
||||
assert resp.status == 200
|
||||
tokens = await resp.json()
|
||||
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
|
||||
assert (
|
||||
await hass.auth.async_validate_access_token(tokens['access_token'])
|
||||
is not None
|
||||
)
|
||||
|
|
|
@ -52,7 +52,7 @@ async def async_get_code(hass, aiohttp_client):
|
|||
'user': user,
|
||||
'code': step['result'],
|
||||
'client': client,
|
||||
'access_token': access_token.token,
|
||||
'access_token': access_token,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -122,11 +122,13 @@ async def test_delete_unable_self_account(hass, hass_ws_client,
|
|||
hass_access_token):
|
||||
"""Test we cannot delete our own account."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
'type': auth_config.WS_TYPE_DELETE,
|
||||
'user_id': hass_access_token.refresh_token.user.id,
|
||||
'user_id': refresh_token.user.id,
|
||||
})
|
||||
|
||||
result = await client.receive_json()
|
||||
|
@ -137,7 +139,9 @@ async def test_delete_unable_self_account(hass, hass_ws_client,
|
|||
async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token):
|
||||
"""Test we cannot delete an unknown user."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_owner = True
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
|
@ -153,7 +157,9 @@ async def test_delete_unknown_user(hass, hass_ws_client, hass_access_token):
|
|||
async def test_delete(hass, hass_ws_client, hass_access_token):
|
||||
"""Test delete command works."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_owner = True
|
||||
test_user = MockUser(
|
||||
id='efg',
|
||||
).add_to_hass(hass)
|
||||
|
@ -174,7 +180,9 @@ async def test_delete(hass, hass_ws_client, hass_access_token):
|
|||
async def test_create(hass, hass_ws_client, hass_access_token):
|
||||
"""Test create command works."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_owner = True
|
||||
|
||||
assert len(await hass.auth.async_get_users()) == 1
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from tests.common import MockUser, register_auth_provider
|
|||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_config(hass, aiohttp_client):
|
||||
def setup_config(hass):
|
||||
"""Fixture that sets up the auth provider homeassistant module."""
|
||||
hass.loop.run_until_complete(register_auth_provider(hass, {
|
||||
'type': 'homeassistant'
|
||||
|
@ -22,7 +22,9 @@ async def test_create_auth_system_generated_user(hass, hass_access_token,
|
|||
"""Test we can't add auth to system generated users."""
|
||||
system_user = MockUser(system_generated=True).add_to_hass(hass)
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_owner = True
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
|
@ -47,7 +49,9 @@ async def test_create_auth_unknown_user(hass_ws_client, hass,
|
|||
hass_access_token):
|
||||
"""Test create pointing at unknown user."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_owner = True
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
|
@ -86,7 +90,9 @@ async def test_create_auth(hass, hass_ws_client, hass_access_token,
|
|||
"""Test create auth command works."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
user = MockUser().add_to_hass(hass)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_owner = True
|
||||
|
||||
assert len(user.credentials) == 0
|
||||
|
||||
|
@ -117,7 +123,9 @@ async def test_create_auth_duplicate_username(hass, hass_ws_client,
|
|||
"""Test we can't create auth with a duplicate username."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
user = MockUser().add_to_hass(hass)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_owner = True
|
||||
|
||||
hass_storage[prov_ha.STORAGE_KEY] = {
|
||||
'version': 1,
|
||||
|
@ -145,7 +153,9 @@ async def test_delete_removes_just_auth(hass_ws_client, hass, hass_storage,
|
|||
hass_access_token):
|
||||
"""Test deleting an auth without being connected to a user."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_owner = True
|
||||
|
||||
hass_storage[prov_ha.STORAGE_KEY] = {
|
||||
'version': 1,
|
||||
|
@ -171,7 +181,9 @@ async def test_delete_removes_credential(hass, hass_ws_client,
|
|||
hass_access_token, hass_storage):
|
||||
"""Test deleting auth that is connected to a user."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_owner = True
|
||||
|
||||
user = MockUser().add_to_hass(hass)
|
||||
user.credentials.append(
|
||||
|
@ -216,7 +228,9 @@ async def test_delete_requires_owner(hass, hass_ws_client, hass_access_token):
|
|||
async def test_delete_unknown_auth(hass, hass_ws_client, hass_access_token):
|
||||
"""Test trying to delete an unknown auth username."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
hass_access_token.refresh_token.user.is_owner = True
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_owner = True
|
||||
|
||||
await client.send_json({
|
||||
'id': 5,
|
||||
|
@ -240,7 +254,9 @@ async def test_change_password(hass, hass_ws_client, hass_access_token):
|
|||
'username': 'test-user'
|
||||
})
|
||||
|
||||
user = hass_access_token.refresh_token.user
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
user = refresh_token.user
|
||||
await hass.auth.async_link_user(user, credentials)
|
||||
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
@ -268,7 +284,9 @@ async def test_change_password_wrong_pw(hass, hass_ws_client,
|
|||
'username': 'test-user'
|
||||
})
|
||||
|
||||
user = hass_access_token.refresh_token.user
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
user = refresh_token.user
|
||||
await hass.auth.async_link_user(user, credentials)
|
||||
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
|
|
@ -28,7 +28,7 @@ def hass_ws_client(aiohttp_client):
|
|||
|
||||
await websocket.send_json({
|
||||
'type': websocket_api.TYPE_AUTH,
|
||||
'access_token': access_token.token
|
||||
'access_token': access_token
|
||||
})
|
||||
|
||||
auth_ok = await websocket.receive_json()
|
||||
|
|
|
@ -106,7 +106,11 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
|
|||
)
|
||||
assert hassio_user is not None
|
||||
assert hassio_user.system_generated
|
||||
assert refresh_token in hassio_user.refresh_tokens
|
||||
for token in hassio_user.refresh_tokens.values():
|
||||
if token.token == refresh_token:
|
||||
break
|
||||
else:
|
||||
assert False, 'refresh token not found'
|
||||
|
||||
|
||||
async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock,
|
||||
|
|
|
@ -156,9 +156,9 @@ async def test_access_with_trusted_ip(app2, aiohttp_client):
|
|||
|
||||
|
||||
async def test_auth_active_access_with_access_token_in_header(
|
||||
app, aiohttp_client, hass_access_token):
|
||||
hass, app, aiohttp_client, hass_access_token):
|
||||
"""Test access with access token in header."""
|
||||
token = hass_access_token.token
|
||||
token = hass_access_token
|
||||
setup_auth(app, [], True, api_password=None)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
|
@ -182,7 +182,9 @@ async def test_auth_active_access_with_access_token_in_header(
|
|||
'/', headers={'Authorization': 'BEARER {}'.format(token)})
|
||||
assert req.status == 401
|
||||
|
||||
hass_access_token.refresh_token.user.is_active = False
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_active = False
|
||||
req = await client.get(
|
||||
'/', headers={'Authorization': 'Bearer {}'.format(token)})
|
||||
assert req.status == 401
|
||||
|
|
|
@ -448,13 +448,15 @@ async def test_api_fire_event_context(hass, mock_api_client,
|
|||
await mock_api_client.post(
|
||||
const.URL_API_EVENTS_EVENT.format("test.event"),
|
||||
headers={
|
||||
'authorization': 'Bearer {}'.format(hass_access_token.token)
|
||||
'authorization': 'Bearer {}'.format(hass_access_token)
|
||||
})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
|
||||
assert len(test_value) == 1
|
||||
assert test_value[0].context.user_id == \
|
||||
hass_access_token.refresh_token.user.id
|
||||
assert test_value[0].context.user_id == refresh_token.user.id
|
||||
|
||||
|
||||
async def test_api_call_service_context(hass, mock_api_client,
|
||||
|
@ -465,12 +467,15 @@ async def test_api_call_service_context(hass, mock_api_client,
|
|||
await mock_api_client.post(
|
||||
'/api/services/test_domain/test_service',
|
||||
headers={
|
||||
'authorization': 'Bearer {}'.format(hass_access_token.token)
|
||||
'authorization': 'Bearer {}'.format(hass_access_token)
|
||||
})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].context.user_id == hass_access_token.refresh_token.user.id
|
||||
assert calls[0].context.user_id == refresh_token.user.id
|
||||
|
||||
|
||||
async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
|
||||
|
@ -481,8 +486,11 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
|
|||
'state': 'on'
|
||||
},
|
||||
headers={
|
||||
'authorization': 'Bearer {}'.format(hass_access_token.token)
|
||||
'authorization': 'Bearer {}'.format(hass_access_token)
|
||||
})
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
|
||||
state = hass.states.get('light.kitchen')
|
||||
assert state.context.user_id == hass_access_token.refresh_token.user.id
|
||||
assert state.context.user_id == refresh_token.user.id
|
||||
|
|
|
@ -334,7 +334,7 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
|
|||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': hass_access_token.token
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
|
@ -344,7 +344,9 @@ async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
|
|||
async def test_auth_active_user_inactive(hass, aiohttp_client,
|
||||
hass_access_token):
|
||||
"""Test authenticating with a token."""
|
||||
hass_access_token.refresh_token.user.is_active = False
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_active = False
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
|
@ -361,7 +363,7 @@ async def test_auth_active_user_inactive(hass, aiohttp_client,
|
|||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': hass_access_token.token
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
|
@ -465,7 +467,7 @@ async def test_call_service_context_with_user(hass, aiohttp_client,
|
|||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': hass_access_token.token
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
|
@ -484,12 +486,15 @@ async def test_call_service_context_with_user(hass, aiohttp_client,
|
|||
msg = await ws.receive_json()
|
||||
assert msg['success']
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
assert call.domain == 'domain_test'
|
||||
assert call.service == 'test_service'
|
||||
assert call.data == {'hello': 'world'}
|
||||
assert call.context.user_id == hass_access_token.refresh_token.user.id
|
||||
assert call.context.user_id == refresh_token.user.id
|
||||
|
||||
|
||||
async def test_call_service_context_no_user(hass, aiohttp_client):
|
||||
|
|
Loading…
Reference in New Issue