2018-07-13 09:43:08 +00:00
|
|
|
"""Storage for auth models."""
|
2018-07-13 13:31:20 +00:00
|
|
|
from collections import OrderedDict
|
2018-07-13 09:43:08 +00:00
|
|
|
from datetime import timedelta
|
2018-10-04 08:41:13 +00:00
|
|
|
import hmac
|
2018-08-16 20:25:41 +00:00
|
|
|
from logging import getLogger
|
|
|
|
from typing import Any, Dict, List, Optional # noqa: F401
|
2018-07-13 09:43:08 +00:00
|
|
|
|
2018-09-11 10:05:15 +00:00
|
|
|
from homeassistant.auth.const import ACCESS_TOKEN_EXPIRATION
|
2018-08-17 18:18:21 +00:00
|
|
|
from homeassistant.core import HomeAssistant, callback
|
2018-07-13 09:43:08 +00:00
|
|
|
from homeassistant.util import dt as dt_util
|
|
|
|
|
|
|
|
from . import models
|
|
|
|
|
|
|
|
# Version passed to the storage helper when AuthStore creates its Store.
STORAGE_VERSION = 1
|
|
|
|
# Key passed to the storage helper when AuthStore creates its Store.
STORAGE_KEY = 'auth'
|
2018-10-08 14:35:38 +00:00
|
|
|
# Name of the group AuthStore creates when no groups have been stored yet.
INITIAL_GROUP_NAME = 'All Access'
|
2018-07-13 09:43:08 +00:00
|
|
|
|
|
|
|
|
|
|
|
class AuthStore:
    """Stores authentication info.

    Any mutation to an object should happen inside the auth store.

    The auth store is lazy. It won't load the data from disk until a method is
    called that needs it.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the auth store.

        Users and groups stay ``None`` until the first load has happened.
        """
        self.hass = hass
        self._users = None  # type: Optional[Dict[str, models.User]]
        self._groups = None  # type: Optional[Dict[str, models.Group]]
        self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY,
                                                 private=True)

    async def async_get_groups(self) -> List[models.Group]:
        """Retrieve all groups, loading the store first if needed."""
        if self._groups is None:
            await self._async_load()
            assert self._groups is not None

        return list(self._groups.values())

    async def async_get_users(self) -> List[models.User]:
        """Retrieve all users, loading the store first if needed."""
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        return list(self._users.values())

    async def async_get_user(self, user_id: str) -> Optional[models.User]:
        """Retrieve a user by id.

        Returns None when no user with that id is known.
        """
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        return self._users.get(user_id)

    async def async_create_user(
            self, name: Optional[str], is_owner: Optional[bool] = None,
            is_active: Optional[bool] = None,
            system_generated: Optional[bool] = None,
            credentials: Optional[models.Credentials] = None,
            groups: Optional[List[models.Group]] = None) -> models.User:
        """Create a new user and register it in the store.

        Flags left as None are not passed on, so the defaults of models.User
        apply. When credentials are given they are linked to the new user.
        """
        if self._users is None:
            await self._async_load()

        assert self._users is not None
        assert self._groups is not None

        kwargs = {
            'name': name,
            # No groups are assigned unless the caller passes some.
            'groups': groups or [],
        }  # type: Dict[str, Any]

        if is_owner is not None:
            kwargs['is_owner'] = is_owner

        if is_active is not None:
            kwargs['is_active'] = is_active

        if system_generated is not None:
            kwargs['system_generated'] = system_generated

        new_user = models.User(**kwargs)

        self._users[new_user.id] = new_user

        if credentials is None:
            self._async_schedule_save()
            return new_user

        # Saving is done inside the link.
        await self.async_link_user(new_user, credentials)
        return new_user

    async def async_link_user(self, user: models.User,
                              credentials: models.Credentials) -> None:
        """Add credentials to an existing user and schedule a save."""
        user.credentials.append(credentials)
        self._async_schedule_save()
        credentials.is_new = False

    async def async_remove_user(self, user: models.User) -> None:
        """Remove a user.

        Raises KeyError when the user is not in the store.
        """
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        self._users.pop(user.id)
        self._async_schedule_save()

    async def async_activate_user(self, user: models.User) -> None:
        """Activate a user."""
        user.is_active = True
        self._async_schedule_save()

    async def async_deactivate_user(self, user: models.User) -> None:
        """Deactivate a user."""
        user.is_active = False
        self._async_schedule_save()

    async def async_remove_credentials(
            self, credentials: models.Credentials) -> None:
        """Remove credentials from the user that holds them.

        Credentials are matched by identity, not by equality. A save is
        scheduled whether or not a match was found.
        """
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        for user in self._users.values():
            found = None

            for index, cred in enumerate(user.credentials):
                if cred is credentials:
                    found = index
                    break

            if found is not None:
                user.credentials.pop(found)
                break

        self._async_schedule_save()

    async def async_create_refresh_token(
            self, user: models.User, client_id: Optional[str] = None,
            client_name: Optional[str] = None,
            client_icon: Optional[str] = None,
            token_type: str = models.TOKEN_TYPE_NORMAL,
            access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION) \
            -> models.RefreshToken:
        """Create a new refresh token for a user and schedule a save.

        client_name and client_icon are only passed on when truthy.
        """
        kwargs = {
            'user': user,
            'client_id': client_id,
            'token_type': token_type,
            'access_token_expiration': access_token_expiration
        }  # type: Dict[str, Any]
        if client_name:
            kwargs['client_name'] = client_name
        if client_icon:
            kwargs['client_icon'] = client_icon

        refresh_token = models.RefreshToken(**kwargs)
        user.refresh_tokens[refresh_token.id] = refresh_token

        self._async_schedule_save()
        return refresh_token

    async def async_remove_refresh_token(
            self, refresh_token: models.RefreshToken) -> None:
        """Remove a refresh token.

        A save is only scheduled when the token was found on a user.
        """
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        for user in self._users.values():
            # Compare against None instead of relying on the truthiness of
            # the popped token object.
            if user.refresh_tokens.pop(refresh_token.id, None) is not None:
                self._async_schedule_save()
                break

    async def async_get_refresh_token(
            self, token_id: str) -> Optional[models.RefreshToken]:
        """Get refresh token by id, or None when no user holds it."""
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        for user in self._users.values():
            refresh_token = user.refresh_tokens.get(token_id)
            if refresh_token is not None:
                return refresh_token

        return None

    async def async_get_refresh_token_by_token(
            self, token: str) -> Optional[models.RefreshToken]:
        """Get refresh token by its token value, or None when unknown."""
        if self._users is None:
            await self._async_load()
            assert self._users is not None

        found = None

        # NOTE(review): the loop does not stop at the first match, presumably
        # so that timing does not reveal where a token sits - confirm before
        # adding an early exit.
        for user in self._users.values():
            for refresh_token in user.refresh_tokens.values():
                if hmac.compare_digest(refresh_token.token, token):
                    found = refresh_token

        return found

    @callback
    def async_log_refresh_token_usage(
            self, refresh_token: models.RefreshToken,
            remote_ip: Optional[str] = None) -> None:
        """Update refresh token last used information and schedule a save."""
        refresh_token.last_used_at = dt_util.utcnow()
        refresh_token.last_used_ip = remote_ip
        self._async_schedule_save()

    async def _async_load(self) -> None:
        """Load users, groups, credentials and refresh tokens from storage.

        Falls back to the defaults when nothing has been stored yet. Refresh
        tokens without a jwt_key or with an unparsable created_at are skipped.
        """
        data = await self._store.async_load()

        # Make sure that we're not overriding data if 2 loads happened at the
        # same time
        if self._users is not None:
            return

        if data is None:
            self._set_defaults()
            return

        users = OrderedDict()  # type: Dict[str, models.User]
        groups = OrderedDict()  # type: Dict[str, models.Group]

        # When creating objects we mention each attribute explicitly. This
        # prevents crashing if user rolls back HA version after a new property
        # was added.

        for group_dict in data.get('groups', []):
            groups[group_dict['id']] = models.Group(
                name=group_dict['name'],
                id=group_dict['id'],
            )

        migrate_group = None

        if not groups:
            # No groups were stored: create the initial group and put every
            # non system generated user in it below.
            migrate_group = models.Group(name=INITIAL_GROUP_NAME)
            groups[migrate_group.id] = migrate_group

        for user_dict in data['users']:
            users[user_dict['id']] = models.User(
                name=user_dict['name'],
                groups=[groups[group_id] for group_id
                        in user_dict.get('group_ids', [])],
                id=user_dict['id'],
                is_owner=user_dict['is_owner'],
                is_active=user_dict['is_active'],
                system_generated=user_dict['system_generated'],
            )
            if migrate_group is not None and not user_dict['system_generated']:
                users[user_dict['id']].groups = [migrate_group]

        for cred_dict in data['credentials']:
            users[cred_dict['user_id']].credentials.append(models.Credentials(
                id=cred_dict['id'],
                is_new=False,
                auth_provider_type=cred_dict['auth_provider_type'],
                auth_provider_id=cred_dict['auth_provider_id'],
                data=cred_dict['data'],
            ))

        for rt_dict in data['refresh_tokens']:
            # Filter out the old keys that don't have jwt_key (pre-0.76)
            if 'jwt_key' not in rt_dict:
                continue

            created_at = dt_util.parse_datetime(rt_dict['created_at'])
            if created_at is None:
                getLogger(__name__).error(
                    'Ignoring refresh token %(id)s with invalid created_at '
                    '%(created_at)s for user_id %(user_id)s', rt_dict)
                continue

            # Stored tokens without a token_type get one derived from
            # whether they carry a client_id.
            token_type = rt_dict.get('token_type')
            if token_type is None:
                if rt_dict['client_id'] is None:
                    token_type = models.TOKEN_TYPE_SYSTEM
                else:
                    token_type = models.TOKEN_TYPE_NORMAL

            # old refresh_token don't have last_used_at (pre-0.78)
            last_used_at_str = rt_dict.get('last_used_at')
            if last_used_at_str:
                last_used_at = dt_util.parse_datetime(last_used_at_str)
            else:
                last_used_at = None

            token = models.RefreshToken(
                id=rt_dict['id'],
                user=users[rt_dict['user_id']],
                client_id=rt_dict['client_id'],
                # use dict.get to keep backward compatibility
                client_name=rt_dict.get('client_name'),
                client_icon=rt_dict.get('client_icon'),
                token_type=token_type,
                created_at=created_at,
                access_token_expiration=timedelta(
                    seconds=rt_dict['access_token_expiration']),
                token=rt_dict['token'],
                jwt_key=rt_dict['jwt_key'],
                last_used_at=last_used_at,
                last_used_ip=rt_dict.get('last_used_ip'),
            )
            users[rt_dict['user_id']].refresh_tokens[token.id] = token

        self._groups = groups
        self._users = users

    @callback
    def _async_schedule_save(self) -> None:
        """Schedule a delayed save; does nothing before the first load."""
        if self._users is None:
            return

        self._store.async_delay_save(self._data_to_save, 1)

    @callback
    def _data_to_save(self) -> Dict:
        """Return the data to store.

        The result has the keys 'users', 'groups', 'credentials' and
        'refresh_tokens', each holding a list of plain dicts.
        """
        assert self._users is not None
        assert self._groups is not None

        users = [
            {
                'id': user.id,
                'group_ids': [group.id for group in user.groups],
                'is_owner': user.is_owner,
                'is_active': user.is_active,
                'name': user.name,
                'system_generated': user.system_generated,
            }
            for user in self._users.values()
        ]

        groups = [
            {
                'name': group.name,
                'id': group.id,
            }
            for group in self._groups.values()
        ]

        credentials = [
            {
                'id': credential.id,
                'user_id': user.id,
                'auth_provider_type': credential.auth_provider_type,
                'auth_provider_id': credential.auth_provider_id,
                'data': credential.data,
            }
            for user in self._users.values()
            for credential in user.credentials
        ]

        refresh_tokens = [
            {
                'id': refresh_token.id,
                'user_id': user.id,
                'client_id': refresh_token.client_id,
                'client_name': refresh_token.client_name,
                'client_icon': refresh_token.client_icon,
                'token_type': refresh_token.token_type,
                'created_at': refresh_token.created_at.isoformat(),
                'access_token_expiration':
                    refresh_token.access_token_expiration.total_seconds(),
                'token': refresh_token.token,
                'jwt_key': refresh_token.jwt_key,
                'last_used_at':
                    refresh_token.last_used_at.isoformat()
                    if refresh_token.last_used_at else None,
                'last_used_ip': refresh_token.last_used_ip,
            }
            for user in self._users.values()
            for refresh_token in user.refresh_tokens.values()
        ]

        return {
            'users': users,
            'groups': groups,
            'credentials': credentials,
            'refresh_tokens': refresh_tokens,
        }

    def _set_defaults(self) -> None:
        """Set default values for auth store.

        Starts with no users and a single group named INITIAL_GROUP_NAME.
        """
        self._users = OrderedDict()  # type: Dict[str, models.User]

        # Add default group
        all_access_group = models.Group(name=INITIAL_GROUP_NAME)

        groups = OrderedDict()  # type: Dict[str, models.Group]
        groups[all_access_group.id] = all_access_group

        self._groups = groups
|