Reorg auth (#15443)
parent
23f1b49e55
commit
b6ca03ce47
|
@ -1,613 +0,0 @@
|
|||
"""Provide an authentication layer for Home Assistant."""
|
||||
import asyncio
|
||||
import binascii
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import attr
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant import data_entry_flow, requirements
|
||||
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = 'auth'
|
||||
|
||||
AUTH_PROVIDERS = Registry()
|
||||
|
||||
AUTH_PROVIDER_SCHEMA = vol.Schema({
|
||||
vol.Required(CONF_TYPE): str,
|
||||
vol.Optional(CONF_NAME): str,
|
||||
# Specify ID if you have two auth providers for same type.
|
||||
vol.Optional(CONF_ID): str,
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
|
||||
DATA_REQS = 'auth_reqs_processed'
|
||||
|
||||
|
||||
def generate_secret(entropy: int = 32) -> str:
|
||||
"""Generate a secret.
|
||||
|
||||
Backport of secrets.token_hex from Python 3.6
|
||||
|
||||
Event loop friendly.
|
||||
"""
|
||||
return binascii.hexlify(os.urandom(entropy)).decode('ascii')
|
||||
|
||||
|
||||
class AuthProvider:
|
||||
"""Provider of user authentication."""
|
||||
|
||||
DEFAULT_TITLE = 'Unnamed auth provider'
|
||||
|
||||
initialized = False
|
||||
|
||||
def __init__(self, hass, store, config):
|
||||
"""Initialize an auth provider."""
|
||||
self.hass = hass
|
||||
self.store = store
|
||||
self.config = config
|
||||
|
||||
@property
|
||||
def id(self): # pylint: disable=invalid-name
|
||||
"""Return id of the auth provider.
|
||||
|
||||
Optional, can be None.
|
||||
"""
|
||||
return self.config.get(CONF_ID)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""Return type of the provider."""
|
||||
return self.config[CONF_TYPE]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the auth provider."""
|
||||
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
|
||||
|
||||
async def async_credentials(self):
|
||||
"""Return all credentials of this provider."""
|
||||
users = await self.store.async_get_users()
|
||||
return [
|
||||
credentials
|
||||
for user in users
|
||||
for credentials in user.credentials
|
||||
if (credentials.auth_provider_type == self.type and
|
||||
credentials.auth_provider_id == self.id)
|
||||
]
|
||||
|
||||
@callback
|
||||
def async_create_credentials(self, data):
|
||||
"""Create credentials."""
|
||||
return Credentials(
|
||||
auth_provider_type=self.type,
|
||||
auth_provider_id=self.id,
|
||||
data=data,
|
||||
)
|
||||
|
||||
# Implement by extending class
|
||||
|
||||
async def async_initialize(self):
|
||||
"""Initialize the auth provider.
|
||||
|
||||
Optional.
|
||||
"""
|
||||
|
||||
async def async_credential_flow(self):
|
||||
"""Return the data flow for logging in with auth provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_get_or_create_credentials(self, flow_result):
|
||||
"""Get credentials based on the flow result."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_user_meta_for_credentials(self, credentials):
|
||||
"""Return extra user metadata for credentials.
|
||||
|
||||
Will be used to populate info when creating a new user.
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class User:
|
||||
"""A user."""
|
||||
|
||||
name = attr.ib(type=str)
|
||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||
is_owner = attr.ib(type=bool, default=False)
|
||||
is_active = attr.ib(type=bool, default=False)
|
||||
system_generated = attr.ib(type=bool, default=False)
|
||||
|
||||
# List of credentials of a user.
|
||||
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
||||
|
||||
# Tokens associated with a user.
|
||||
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class RefreshToken:
|
||||
"""RefreshToken for a user to grant new access tokens."""
|
||||
|
||||
user = attr.ib(type=User)
|
||||
client_id = attr.ib(type=str)
|
||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
||||
access_token_expiration = attr.ib(type=timedelta,
|
||||
default=ACCESS_TOKEN_EXPIRATION)
|
||||
token = attr.ib(type=str,
|
||||
default=attr.Factory(lambda: generate_secret(64)))
|
||||
access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class AccessToken:
|
||||
"""Access token to access the API.
|
||||
|
||||
These will only ever be stored in memory and not be persisted.
|
||||
"""
|
||||
|
||||
refresh_token = attr.ib(type=RefreshToken)
|
||||
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
||||
token = attr.ib(type=str,
|
||||
default=attr.Factory(generate_secret))
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
"""Return if this token has expired."""
|
||||
expires = self.created_at + self.refresh_token.access_token_expiration
|
||||
return dt_util.utcnow() > expires
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class Credentials:
|
||||
"""Credentials for a user on an auth provider."""
|
||||
|
||||
auth_provider_type = attr.ib(type=str)
|
||||
auth_provider_id = attr.ib(type=str)
|
||||
|
||||
# Allow the auth provider to store data to represent their auth.
|
||||
data = attr.ib(type=dict)
|
||||
|
||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||
is_new = attr.ib(type=bool, default=True)
|
||||
|
||||
|
||||
async def load_auth_provider_module(hass, provider):
|
||||
"""Load an auth provider."""
|
||||
try:
|
||||
module = importlib.import_module(
|
||||
'homeassistant.auth_providers.{}'.format(provider))
|
||||
except ImportError:
|
||||
_LOGGER.warning('Unable to find auth provider %s', provider)
|
||||
return None
|
||||
|
||||
if hass.config.skip_pip or not hasattr(module, 'REQUIREMENTS'):
|
||||
return module
|
||||
|
||||
processed = hass.data.get(DATA_REQS)
|
||||
|
||||
if processed is None:
|
||||
processed = hass.data[DATA_REQS] = set()
|
||||
elif provider in processed:
|
||||
return module
|
||||
|
||||
req_success = await requirements.async_process_requirements(
|
||||
hass, 'auth provider {}'.format(provider), module.REQUIREMENTS)
|
||||
|
||||
if not req_success:
|
||||
return None
|
||||
|
||||
return module
|
||||
|
||||
|
||||
async def auth_manager_from_config(hass, provider_configs):
|
||||
"""Initialize an auth manager from config."""
|
||||
store = AuthStore(hass)
|
||||
if provider_configs:
|
||||
providers = await asyncio.gather(
|
||||
*[_auth_provider_from_config(hass, store, config)
|
||||
for config in provider_configs])
|
||||
else:
|
||||
providers = []
|
||||
# So returned auth providers are in same order as config
|
||||
provider_hash = OrderedDict()
|
||||
for provider in providers:
|
||||
if provider is None:
|
||||
continue
|
||||
|
||||
key = (provider.type, provider.id)
|
||||
|
||||
if key in provider_hash:
|
||||
_LOGGER.error(
|
||||
'Found duplicate provider: %s. Please add unique IDs if you '
|
||||
'want to have the same provider twice.', key)
|
||||
continue
|
||||
|
||||
provider_hash[key] = provider
|
||||
manager = AuthManager(hass, store, provider_hash)
|
||||
return manager
|
||||
|
||||
|
||||
async def _auth_provider_from_config(hass, store, config):
|
||||
"""Initialize an auth provider from a config."""
|
||||
provider_name = config[CONF_TYPE]
|
||||
module = await load_auth_provider_module(hass, provider_name)
|
||||
|
||||
if module is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
config = module.CONFIG_SCHEMA(config)
|
||||
except vol.Invalid as err:
|
||||
_LOGGER.error('Invalid configuration for auth provider %s: %s',
|
||||
provider_name, humanize_error(config, err))
|
||||
return None
|
||||
|
||||
return AUTH_PROVIDERS[provider_name](hass, store, config)
|
||||
|
||||
|
||||
class AuthManager:
|
||||
"""Manage the authentication for Home Assistant."""
|
||||
|
||||
def __init__(self, hass, store, providers):
|
||||
"""Initialize the auth manager."""
|
||||
self._store = store
|
||||
self._providers = providers
|
||||
self.login_flow = data_entry_flow.FlowManager(
|
||||
hass, self._async_create_login_flow,
|
||||
self._async_finish_login_flow)
|
||||
self._access_tokens = {}
|
||||
|
||||
@property
|
||||
def active(self):
|
||||
"""Return if any auth providers are registered."""
|
||||
return bool(self._providers)
|
||||
|
||||
@property
|
||||
def support_legacy(self):
|
||||
"""
|
||||
Return if legacy_api_password auth providers are registered.
|
||||
|
||||
Should be removed when we removed legacy_api_password auth providers.
|
||||
"""
|
||||
for provider_type, _ in self._providers:
|
||||
if provider_type == 'legacy_api_password':
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def async_auth_providers(self):
|
||||
"""Return a list of available auth providers."""
|
||||
return self._providers.values()
|
||||
|
||||
async def async_get_user(self, user_id):
|
||||
"""Retrieve a user."""
|
||||
return await self._store.async_get_user(user_id)
|
||||
|
||||
async def async_create_system_user(self, name):
|
||||
"""Create a system user."""
|
||||
return await self._store.async_create_user(
|
||||
name=name,
|
||||
system_generated=True,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
async def async_get_or_create_user(self, credentials):
|
||||
"""Get or create a user."""
|
||||
if not credentials.is_new:
|
||||
for user in await self._store.async_get_users():
|
||||
for creds in user.credentials:
|
||||
if creds.id == credentials.id:
|
||||
return user
|
||||
|
||||
raise ValueError('Unable to find the user.')
|
||||
|
||||
auth_provider = self._async_get_auth_provider(credentials)
|
||||
info = await auth_provider.async_user_meta_for_credentials(
|
||||
credentials)
|
||||
|
||||
kwargs = {
|
||||
'credentials': credentials,
|
||||
'name': info.get('name')
|
||||
}
|
||||
|
||||
# Make owner and activate user if it's the first user.
|
||||
if await self._store.async_get_users():
|
||||
kwargs['is_owner'] = False
|
||||
kwargs['is_active'] = False
|
||||
else:
|
||||
kwargs['is_owner'] = True
|
||||
kwargs['is_active'] = True
|
||||
|
||||
return await self._store.async_create_user(**kwargs)
|
||||
|
||||
async def async_link_user(self, user, credentials):
|
||||
"""Link credentials to an existing user."""
|
||||
await self._store.async_link_user(user, credentials)
|
||||
|
||||
async def async_remove_user(self, user):
|
||||
"""Remove a user."""
|
||||
await self._store.async_remove_user(user)
|
||||
|
||||
async def async_create_refresh_token(self, user, client_id=None):
|
||||
"""Create a new refresh token for a user."""
|
||||
if not user.is_active:
|
||||
raise ValueError('User is not active')
|
||||
|
||||
if user.system_generated and client_id is not None:
|
||||
raise ValueError(
|
||||
'System generated users cannot have refresh tokens connected '
|
||||
'to a client.')
|
||||
|
||||
if not user.system_generated and client_id is None:
|
||||
raise ValueError('Client is required to generate a refresh token.')
|
||||
|
||||
return await self._store.async_create_refresh_token(user, client_id)
|
||||
|
||||
async def async_get_refresh_token(self, token):
|
||||
"""Get refresh token by token."""
|
||||
return await self._store.async_get_refresh_token(token)
|
||||
|
||||
@callback
|
||||
def async_create_access_token(self, refresh_token):
|
||||
"""Create a new access token."""
|
||||
access_token = AccessToken(refresh_token=refresh_token)
|
||||
self._access_tokens[access_token.token] = access_token
|
||||
return access_token
|
||||
|
||||
@callback
|
||||
def async_get_access_token(self, token):
|
||||
"""Get an access token."""
|
||||
tkn = self._access_tokens.get(token)
|
||||
|
||||
if tkn is None:
|
||||
return None
|
||||
|
||||
if tkn.expired:
|
||||
self._access_tokens.pop(token)
|
||||
return None
|
||||
|
||||
return tkn
|
||||
|
||||
async def _async_create_login_flow(self, handler, *, source, data):
|
||||
"""Create a login flow."""
|
||||
auth_provider = self._providers[handler]
|
||||
|
||||
if not auth_provider.initialized:
|
||||
auth_provider.initialized = True
|
||||
await auth_provider.async_initialize()
|
||||
|
||||
return await auth_provider.async_credential_flow()
|
||||
|
||||
async def _async_finish_login_flow(self, result):
|
||||
"""Result of a credential login flow."""
|
||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return None
|
||||
|
||||
auth_provider = self._providers[result['handler']]
|
||||
return await auth_provider.async_get_or_create_credentials(
|
||||
result['data'])
|
||||
|
||||
@callback
|
||||
def _async_get_auth_provider(self, credentials):
|
||||
"""Helper to get auth provider from a set of credentials."""
|
||||
auth_provider_key = (credentials.auth_provider_type,
|
||||
credentials.auth_provider_id)
|
||||
return self._providers[auth_provider_key]
|
||||
|
||||
|
||||
class AuthStore:
|
||||
"""Stores authentication info.
|
||||
|
||||
Any mutation to an object should happen inside the auth store.
|
||||
|
||||
The auth store is lazy. It won't load the data from disk until a method is
|
||||
called that needs it.
|
||||
"""
|
||||
|
||||
def __init__(self, hass):
|
||||
"""Initialize the auth store."""
|
||||
self.hass = hass
|
||||
self._users = None
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
async def async_get_users(self):
|
||||
"""Retrieve all users."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
return list(self._users.values())
|
||||
|
||||
async def async_get_user(self, user_id):
|
||||
"""Retrieve a user by id."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
return self._users.get(user_id)
|
||||
|
||||
async def async_create_user(self, name, is_owner=None, is_active=None,
|
||||
system_generated=None, credentials=None):
|
||||
"""Create a new user."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
kwargs = {
|
||||
'name': name
|
||||
}
|
||||
|
||||
if is_owner is not None:
|
||||
kwargs['is_owner'] = is_owner
|
||||
|
||||
if is_active is not None:
|
||||
kwargs['is_active'] = is_active
|
||||
|
||||
if system_generated is not None:
|
||||
kwargs['system_generated'] = system_generated
|
||||
|
||||
new_user = User(**kwargs)
|
||||
|
||||
self._users[new_user.id] = new_user
|
||||
|
||||
if credentials is None:
|
||||
await self.async_save()
|
||||
return new_user
|
||||
|
||||
# Saving is done inside the link.
|
||||
await self.async_link_user(new_user, credentials)
|
||||
return new_user
|
||||
|
||||
async def async_link_user(self, user, credentials):
|
||||
"""Add credentials to an existing user."""
|
||||
user.credentials.append(credentials)
|
||||
await self.async_save()
|
||||
credentials.is_new = False
|
||||
|
||||
async def async_remove_user(self, user):
|
||||
"""Remove a user."""
|
||||
self._users.pop(user.id)
|
||||
await self.async_save()
|
||||
|
||||
async def async_create_refresh_token(self, user, client_id=None):
|
||||
"""Create a new token for a user."""
|
||||
refresh_token = RefreshToken(user=user, client_id=client_id)
|
||||
user.refresh_tokens[refresh_token.token] = refresh_token
|
||||
await self.async_save()
|
||||
return refresh_token
|
||||
|
||||
async def async_get_refresh_token(self, token):
|
||||
"""Get refresh token by token."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
for user in self._users.values():
|
||||
refresh_token = user.refresh_tokens.get(token)
|
||||
if refresh_token is not None:
|
||||
return refresh_token
|
||||
|
||||
return None
|
||||
|
||||
async def async_load(self):
|
||||
"""Load the users."""
|
||||
data = await self._store.async_load()
|
||||
|
||||
# Make sure that we're not overriding data if 2 loads happened at the
|
||||
# same time
|
||||
if self._users is not None:
|
||||
return
|
||||
|
||||
if data is None:
|
||||
self._users = {}
|
||||
return
|
||||
|
||||
users = {
|
||||
user_dict['id']: User(**user_dict) for user_dict in data['users']
|
||||
}
|
||||
|
||||
for cred_dict in data['credentials']:
|
||||
users[cred_dict['user_id']].credentials.append(Credentials(
|
||||
id=cred_dict['id'],
|
||||
is_new=False,
|
||||
auth_provider_type=cred_dict['auth_provider_type'],
|
||||
auth_provider_id=cred_dict['auth_provider_id'],
|
||||
data=cred_dict['data'],
|
||||
))
|
||||
|
||||
refresh_tokens = {}
|
||||
|
||||
for rt_dict in data['refresh_tokens']:
|
||||
token = RefreshToken(
|
||||
id=rt_dict['id'],
|
||||
user=users[rt_dict['user_id']],
|
||||
client_id=rt_dict['client_id'],
|
||||
created_at=dt_util.parse_datetime(rt_dict['created_at']),
|
||||
access_token_expiration=timedelta(
|
||||
seconds=rt_dict['access_token_expiration']),
|
||||
token=rt_dict['token'],
|
||||
)
|
||||
refresh_tokens[token.id] = token
|
||||
users[rt_dict['user_id']].refresh_tokens[token.token] = token
|
||||
|
||||
for ac_dict in data['access_tokens']:
|
||||
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
|
||||
token = AccessToken(
|
||||
refresh_token=refresh_token,
|
||||
created_at=dt_util.parse_datetime(ac_dict['created_at']),
|
||||
token=ac_dict['token'],
|
||||
)
|
||||
refresh_token.access_tokens.append(token)
|
||||
|
||||
self._users = users
|
||||
|
||||
async def async_save(self):
|
||||
"""Save users."""
|
||||
users = [
|
||||
{
|
||||
'id': user.id,
|
||||
'is_owner': user.is_owner,
|
||||
'is_active': user.is_active,
|
||||
'name': user.name,
|
||||
'system_generated': user.system_generated,
|
||||
}
|
||||
for user in self._users.values()
|
||||
]
|
||||
|
||||
credentials = [
|
||||
{
|
||||
'id': credential.id,
|
||||
'user_id': user.id,
|
||||
'auth_provider_type': credential.auth_provider_type,
|
||||
'auth_provider_id': credential.auth_provider_id,
|
||||
'data': credential.data,
|
||||
}
|
||||
for user in self._users.values()
|
||||
for credential in user.credentials
|
||||
]
|
||||
|
||||
refresh_tokens = [
|
||||
{
|
||||
'id': refresh_token.id,
|
||||
'user_id': user.id,
|
||||
'client_id': refresh_token.client_id,
|
||||
'created_at': refresh_token.created_at.isoformat(),
|
||||
'access_token_expiration':
|
||||
refresh_token.access_token_expiration.total_seconds(),
|
||||
'token': refresh_token.token,
|
||||
}
|
||||
for user in self._users.values()
|
||||
for refresh_token in user.refresh_tokens.values()
|
||||
]
|
||||
|
||||
access_tokens = [
|
||||
{
|
||||
'id': user.id,
|
||||
'refresh_token_id': refresh_token.id,
|
||||
'created_at': access_token.created_at.isoformat(),
|
||||
'token': access_token.token,
|
||||
}
|
||||
for user in self._users.values()
|
||||
for refresh_token in user.refresh_tokens.values()
|
||||
for access_token in refresh_token.access_tokens
|
||||
]
|
||||
|
||||
data = {
|
||||
'users': users,
|
||||
'credentials': credentials,
|
||||
'access_tokens': access_tokens,
|
||||
'refresh_tokens': refresh_tokens,
|
||||
}
|
||||
|
||||
await self._store.async_save(data, delay=1)
|
|
@ -0,0 +1,191 @@
|
|||
"""Provide an authentication layer for Home Assistant."""
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import callback
|
||||
|
||||
from . import models
|
||||
from . import auth_store
|
||||
from .providers import auth_provider_from_config
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def auth_manager_from_config(hass, provider_configs):
|
||||
"""Initialize an auth manager from config."""
|
||||
store = auth_store.AuthStore(hass)
|
||||
if provider_configs:
|
||||
providers = await asyncio.gather(
|
||||
*[auth_provider_from_config(hass, store, config)
|
||||
for config in provider_configs])
|
||||
else:
|
||||
providers = []
|
||||
# So returned auth providers are in same order as config
|
||||
provider_hash = OrderedDict()
|
||||
for provider in providers:
|
||||
if provider is None:
|
||||
continue
|
||||
|
||||
key = (provider.type, provider.id)
|
||||
|
||||
if key in provider_hash:
|
||||
_LOGGER.error(
|
||||
'Found duplicate provider: %s. Please add unique IDs if you '
|
||||
'want to have the same provider twice.', key)
|
||||
continue
|
||||
|
||||
provider_hash[key] = provider
|
||||
manager = AuthManager(hass, store, provider_hash)
|
||||
return manager
|
||||
|
||||
|
||||
class AuthManager:
|
||||
"""Manage the authentication for Home Assistant."""
|
||||
|
||||
def __init__(self, hass, store, providers):
|
||||
"""Initialize the auth manager."""
|
||||
self._store = store
|
||||
self._providers = providers
|
||||
self.login_flow = data_entry_flow.FlowManager(
|
||||
hass, self._async_create_login_flow,
|
||||
self._async_finish_login_flow)
|
||||
self._access_tokens = {}
|
||||
|
||||
@property
|
||||
def active(self):
|
||||
"""Return if any auth providers are registered."""
|
||||
return bool(self._providers)
|
||||
|
||||
@property
|
||||
def support_legacy(self):
|
||||
"""
|
||||
Return if legacy_api_password auth providers are registered.
|
||||
|
||||
Should be removed when we removed legacy_api_password auth providers.
|
||||
"""
|
||||
for provider_type, _ in self._providers:
|
||||
if provider_type == 'legacy_api_password':
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def async_auth_providers(self):
|
||||
"""Return a list of available auth providers."""
|
||||
return self._providers.values()
|
||||
|
||||
async def async_get_user(self, user_id):
|
||||
"""Retrieve a user."""
|
||||
return await self._store.async_get_user(user_id)
|
||||
|
||||
async def async_create_system_user(self, name):
|
||||
"""Create a system user."""
|
||||
return await self._store.async_create_user(
|
||||
name=name,
|
||||
system_generated=True,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
async def async_get_or_create_user(self, credentials):
|
||||
"""Get or create a user."""
|
||||
if not credentials.is_new:
|
||||
for user in await self._store.async_get_users():
|
||||
for creds in user.credentials:
|
||||
if creds.id == credentials.id:
|
||||
return user
|
||||
|
||||
raise ValueError('Unable to find the user.')
|
||||
|
||||
auth_provider = self._async_get_auth_provider(credentials)
|
||||
info = await auth_provider.async_user_meta_for_credentials(
|
||||
credentials)
|
||||
|
||||
kwargs = {
|
||||
'credentials': credentials,
|
||||
'name': info.get('name')
|
||||
}
|
||||
|
||||
# Make owner and activate user if it's the first user.
|
||||
if await self._store.async_get_users():
|
||||
kwargs['is_owner'] = False
|
||||
kwargs['is_active'] = False
|
||||
else:
|
||||
kwargs['is_owner'] = True
|
||||
kwargs['is_active'] = True
|
||||
|
||||
return await self._store.async_create_user(**kwargs)
|
||||
|
||||
async def async_link_user(self, user, credentials):
|
||||
"""Link credentials to an existing user."""
|
||||
await self._store.async_link_user(user, credentials)
|
||||
|
||||
async def async_remove_user(self, user):
|
||||
"""Remove a user."""
|
||||
await self._store.async_remove_user(user)
|
||||
|
||||
async def async_create_refresh_token(self, user, client_id=None):
|
||||
"""Create a new refresh token for a user."""
|
||||
if not user.is_active:
|
||||
raise ValueError('User is not active')
|
||||
|
||||
if user.system_generated and client_id is not None:
|
||||
raise ValueError(
|
||||
'System generated users cannot have refresh tokens connected '
|
||||
'to a client.')
|
||||
|
||||
if not user.system_generated and client_id is None:
|
||||
raise ValueError('Client is required to generate a refresh token.')
|
||||
|
||||
return await self._store.async_create_refresh_token(user, client_id)
|
||||
|
||||
async def async_get_refresh_token(self, token):
|
||||
"""Get refresh token by token."""
|
||||
return await self._store.async_get_refresh_token(token)
|
||||
|
||||
@callback
|
||||
def async_create_access_token(self, refresh_token):
|
||||
"""Create a new access token."""
|
||||
access_token = models.AccessToken(refresh_token=refresh_token)
|
||||
self._access_tokens[access_token.token] = access_token
|
||||
return access_token
|
||||
|
||||
@callback
|
||||
def async_get_access_token(self, token):
|
||||
"""Get an access token."""
|
||||
tkn = self._access_tokens.get(token)
|
||||
|
||||
if tkn is None:
|
||||
return None
|
||||
|
||||
if tkn.expired:
|
||||
self._access_tokens.pop(token)
|
||||
return None
|
||||
|
||||
return tkn
|
||||
|
||||
async def _async_create_login_flow(self, handler, *, source, data):
|
||||
"""Create a login flow."""
|
||||
auth_provider = self._providers[handler]
|
||||
|
||||
if not auth_provider.initialized:
|
||||
auth_provider.initialized = True
|
||||
await auth_provider.async_initialize()
|
||||
|
||||
return await auth_provider.async_credential_flow()
|
||||
|
||||
async def _async_finish_login_flow(self, result):
|
||||
"""Result of a credential login flow."""
|
||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return None
|
||||
|
||||
auth_provider = self._providers[result['handler']]
|
||||
return await auth_provider.async_get_or_create_credentials(
|
||||
result['data'])
|
||||
|
||||
@callback
|
||||
def _async_get_auth_provider(self, credentials):
|
||||
"""Helper to get auth provider from a set of credentials."""
|
||||
auth_provider_key = (credentials.auth_provider_type,
|
||||
credentials.auth_provider_id)
|
||||
return self._providers[auth_provider_key]
|
|
@ -0,0 +1,213 @@
|
|||
"""Storage for auth models."""
|
||||
from datetime import timedelta
|
||||
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import models
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = 'auth'
|
||||
|
||||
|
||||
class AuthStore:
|
||||
"""Stores authentication info.
|
||||
|
||||
Any mutation to an object should happen inside the auth store.
|
||||
|
||||
The auth store is lazy. It won't load the data from disk until a method is
|
||||
called that needs it.
|
||||
"""
|
||||
|
||||
def __init__(self, hass):
|
||||
"""Initialize the auth store."""
|
||||
self.hass = hass
|
||||
self._users = None
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
async def async_get_users(self):
|
||||
"""Retrieve all users."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
return list(self._users.values())
|
||||
|
||||
async def async_get_user(self, user_id):
|
||||
"""Retrieve a user by id."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
return self._users.get(user_id)
|
||||
|
||||
async def async_create_user(self, name, is_owner=None, is_active=None,
|
||||
system_generated=None, credentials=None):
|
||||
"""Create a new user."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
kwargs = {
|
||||
'name': name
|
||||
}
|
||||
|
||||
if is_owner is not None:
|
||||
kwargs['is_owner'] = is_owner
|
||||
|
||||
if is_active is not None:
|
||||
kwargs['is_active'] = is_active
|
||||
|
||||
if system_generated is not None:
|
||||
kwargs['system_generated'] = system_generated
|
||||
|
||||
new_user = models.User(**kwargs)
|
||||
|
||||
self._users[new_user.id] = new_user
|
||||
|
||||
if credentials is None:
|
||||
await self.async_save()
|
||||
return new_user
|
||||
|
||||
# Saving is done inside the link.
|
||||
await self.async_link_user(new_user, credentials)
|
||||
return new_user
|
||||
|
||||
async def async_link_user(self, user, credentials):
|
||||
"""Add credentials to an existing user."""
|
||||
user.credentials.append(credentials)
|
||||
await self.async_save()
|
||||
credentials.is_new = False
|
||||
|
||||
async def async_remove_user(self, user):
|
||||
"""Remove a user."""
|
||||
self._users.pop(user.id)
|
||||
await self.async_save()
|
||||
|
||||
async def async_create_refresh_token(self, user, client_id=None):
|
||||
"""Create a new token for a user."""
|
||||
refresh_token = models.RefreshToken(user=user, client_id=client_id)
|
||||
user.refresh_tokens[refresh_token.token] = refresh_token
|
||||
await self.async_save()
|
||||
return refresh_token
|
||||
|
||||
async def async_get_refresh_token(self, token):
|
||||
"""Get refresh token by token."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
for user in self._users.values():
|
||||
refresh_token = user.refresh_tokens.get(token)
|
||||
if refresh_token is not None:
|
||||
return refresh_token
|
||||
|
||||
return None
|
||||
|
||||
async def async_load(self):
|
||||
"""Load the users."""
|
||||
data = await self._store.async_load()
|
||||
|
||||
# Make sure that we're not overriding data if 2 loads happened at the
|
||||
# same time
|
||||
if self._users is not None:
|
||||
return
|
||||
|
||||
if data is None:
|
||||
self._users = {}
|
||||
return
|
||||
|
||||
users = {
|
||||
user_dict['id']: models.User(**user_dict)
|
||||
for user_dict in data['users']
|
||||
}
|
||||
|
||||
for cred_dict in data['credentials']:
|
||||
users[cred_dict['user_id']].credentials.append(models.Credentials(
|
||||
id=cred_dict['id'],
|
||||
is_new=False,
|
||||
auth_provider_type=cred_dict['auth_provider_type'],
|
||||
auth_provider_id=cred_dict['auth_provider_id'],
|
||||
data=cred_dict['data'],
|
||||
))
|
||||
|
||||
refresh_tokens = {}
|
||||
|
||||
for rt_dict in data['refresh_tokens']:
|
||||
token = models.RefreshToken(
|
||||
id=rt_dict['id'],
|
||||
user=users[rt_dict['user_id']],
|
||||
client_id=rt_dict['client_id'],
|
||||
created_at=dt_util.parse_datetime(rt_dict['created_at']),
|
||||
access_token_expiration=timedelta(
|
||||
seconds=rt_dict['access_token_expiration']),
|
||||
token=rt_dict['token'],
|
||||
)
|
||||
refresh_tokens[token.id] = token
|
||||
users[rt_dict['user_id']].refresh_tokens[token.token] = token
|
||||
|
||||
for ac_dict in data['access_tokens']:
|
||||
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
|
||||
token = models.AccessToken(
|
||||
refresh_token=refresh_token,
|
||||
created_at=dt_util.parse_datetime(ac_dict['created_at']),
|
||||
token=ac_dict['token'],
|
||||
)
|
||||
refresh_token.access_tokens.append(token)
|
||||
|
||||
self._users = users
|
||||
|
||||
async def async_save(self):
|
||||
"""Save users."""
|
||||
users = [
|
||||
{
|
||||
'id': user.id,
|
||||
'is_owner': user.is_owner,
|
||||
'is_active': user.is_active,
|
||||
'name': user.name,
|
||||
'system_generated': user.system_generated,
|
||||
}
|
||||
for user in self._users.values()
|
||||
]
|
||||
|
||||
credentials = [
|
||||
{
|
||||
'id': credential.id,
|
||||
'user_id': user.id,
|
||||
'auth_provider_type': credential.auth_provider_type,
|
||||
'auth_provider_id': credential.auth_provider_id,
|
||||
'data': credential.data,
|
||||
}
|
||||
for user in self._users.values()
|
||||
for credential in user.credentials
|
||||
]
|
||||
|
||||
refresh_tokens = [
|
||||
{
|
||||
'id': refresh_token.id,
|
||||
'user_id': user.id,
|
||||
'client_id': refresh_token.client_id,
|
||||
'created_at': refresh_token.created_at.isoformat(),
|
||||
'access_token_expiration':
|
||||
refresh_token.access_token_expiration.total_seconds(),
|
||||
'token': refresh_token.token,
|
||||
}
|
||||
for user in self._users.values()
|
||||
for refresh_token in user.refresh_tokens.values()
|
||||
]
|
||||
|
||||
access_tokens = [
|
||||
{
|
||||
'id': user.id,
|
||||
'refresh_token_id': refresh_token.id,
|
||||
'created_at': access_token.created_at.isoformat(),
|
||||
'token': access_token.token,
|
||||
}
|
||||
for user in self._users.values()
|
||||
for refresh_token in user.refresh_tokens.values()
|
||||
for access_token in refresh_token.access_tokens
|
||||
]
|
||||
|
||||
data = {
|
||||
'users': users,
|
||||
'credentials': credentials,
|
||||
'access_tokens': access_tokens,
|
||||
'refresh_tokens': refresh_tokens,
|
||||
}
|
||||
|
||||
await self._store.async_save(data, delay=1)
|
|
@ -0,0 +1,4 @@
|
|||
"""Constants for the auth module."""
|
||||
from datetime import timedelta
|
||||
|
||||
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
|
|
@ -0,0 +1,75 @@
|
|||
"""Auth models."""
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
|
||||
import attr
|
||||
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .const import ACCESS_TOKEN_EXPIRATION
|
||||
from .util import generate_secret
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class User:
|
||||
"""A user."""
|
||||
|
||||
name = attr.ib(type=str)
|
||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||
is_owner = attr.ib(type=bool, default=False)
|
||||
is_active = attr.ib(type=bool, default=False)
|
||||
system_generated = attr.ib(type=bool, default=False)
|
||||
|
||||
# List of credentials of a user.
|
||||
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
||||
|
||||
# Tokens associated with a user.
|
||||
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class RefreshToken:
|
||||
"""RefreshToken for a user to grant new access tokens."""
|
||||
|
||||
user = attr.ib(type=User)
|
||||
client_id = attr.ib(type=str)
|
||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
||||
access_token_expiration = attr.ib(type=timedelta,
|
||||
default=ACCESS_TOKEN_EXPIRATION)
|
||||
token = attr.ib(type=str,
|
||||
default=attr.Factory(lambda: generate_secret(64)))
|
||||
access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class AccessToken:
|
||||
"""Access token to access the API.
|
||||
|
||||
These will only ever be stored in memory and not be persisted.
|
||||
"""
|
||||
|
||||
refresh_token = attr.ib(type=RefreshToken)
|
||||
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
||||
token = attr.ib(type=str,
|
||||
default=attr.Factory(generate_secret))
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
"""Return if this token has expired."""
|
||||
expires = self.created_at + self.refresh_token.access_token_expiration
|
||||
return dt_util.utcnow() > expires
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class Credentials:
|
||||
"""Credentials for a user on an auth provider."""
|
||||
|
||||
auth_provider_type = attr.ib(type=str)
|
||||
auth_provider_id = attr.ib(type=str)
|
||||
|
||||
# Allow the auth provider to store data to represent their auth.
|
||||
data = attr.ib(type=dict)
|
||||
|
||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||
is_new = attr.ib(type=bool, default=True)
|
|
@ -0,0 +1,147 @@
|
|||
"""Auth providers for Home Assistant."""
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant import requirements
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
from homeassistant.auth.models import Credentials
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DATA_REQS = 'auth_prov_reqs_processed'
|
||||
|
||||
AUTH_PROVIDERS = Registry()
|
||||
|
||||
AUTH_PROVIDER_SCHEMA = vol.Schema({
|
||||
vol.Required(CONF_TYPE): str,
|
||||
vol.Optional(CONF_NAME): str,
|
||||
# Specify ID if you have two auth providers for same type.
|
||||
vol.Optional(CONF_ID): str,
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
async def auth_provider_from_config(hass, store, config):
|
||||
"""Initialize an auth provider from a config."""
|
||||
provider_name = config[CONF_TYPE]
|
||||
module = await load_auth_provider_module(hass, provider_name)
|
||||
|
||||
if module is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
config = module.CONFIG_SCHEMA(config)
|
||||
except vol.Invalid as err:
|
||||
_LOGGER.error('Invalid configuration for auth provider %s: %s',
|
||||
provider_name, humanize_error(config, err))
|
||||
return None
|
||||
|
||||
return AUTH_PROVIDERS[provider_name](hass, store, config)
|
||||
|
||||
|
||||
async def load_auth_provider_module(hass, provider):
|
||||
"""Load an auth provider."""
|
||||
try:
|
||||
module = importlib.import_module(
|
||||
'homeassistant.auth.providers.{}'.format(provider))
|
||||
except ImportError:
|
||||
_LOGGER.warning('Unable to find auth provider %s', provider)
|
||||
return None
|
||||
|
||||
if hass.config.skip_pip or not hasattr(module, 'REQUIREMENTS'):
|
||||
return module
|
||||
|
||||
processed = hass.data.get(DATA_REQS)
|
||||
|
||||
if processed is None:
|
||||
processed = hass.data[DATA_REQS] = set()
|
||||
elif provider in processed:
|
||||
return module
|
||||
|
||||
req_success = await requirements.async_process_requirements(
|
||||
hass, 'auth provider {}'.format(provider), module.REQUIREMENTS)
|
||||
|
||||
if not req_success:
|
||||
return None
|
||||
|
||||
processed.add(provider)
|
||||
return module
|
||||
|
||||
|
||||
class AuthProvider:
|
||||
"""Provider of user authentication."""
|
||||
|
||||
DEFAULT_TITLE = 'Unnamed auth provider'
|
||||
|
||||
initialized = False
|
||||
|
||||
def __init__(self, hass, store, config):
|
||||
"""Initialize an auth provider."""
|
||||
self.hass = hass
|
||||
self.store = store
|
||||
self.config = config
|
||||
|
||||
@property
|
||||
def id(self): # pylint: disable=invalid-name
|
||||
"""Return id of the auth provider.
|
||||
|
||||
Optional, can be None.
|
||||
"""
|
||||
return self.config.get(CONF_ID)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""Return type of the provider."""
|
||||
return self.config[CONF_TYPE]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the auth provider."""
|
||||
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
|
||||
|
||||
async def async_credentials(self):
|
||||
"""Return all credentials of this provider."""
|
||||
users = await self.store.async_get_users()
|
||||
return [
|
||||
credentials
|
||||
for user in users
|
||||
for credentials in user.credentials
|
||||
if (credentials.auth_provider_type == self.type and
|
||||
credentials.auth_provider_id == self.id)
|
||||
]
|
||||
|
||||
@callback
|
||||
def async_create_credentials(self, data):
|
||||
"""Create credentials."""
|
||||
return Credentials(
|
||||
auth_provider_type=self.type,
|
||||
auth_provider_id=self.id,
|
||||
data=data,
|
||||
)
|
||||
|
||||
# Implement by extending class
|
||||
|
||||
async def async_initialize(self):
|
||||
"""Initialize the auth provider.
|
||||
|
||||
Optional.
|
||||
"""
|
||||
|
||||
async def async_credential_flow(self):
|
||||
"""Return the data flow for logging in with auth provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_get_or_create_credentials(self, flow_result):
|
||||
"""Get credentials based on the flow result."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_user_meta_for_credentials(self, credentials):
|
||||
"""Return extra user metadata for credentials.
|
||||
|
||||
Will be used to populate info when creating a new user.
|
||||
"""
|
||||
return {}
|
|
@ -6,14 +6,17 @@ import hmac
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import auth, data_entry_flow
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from homeassistant.auth.util import generate_secret
|
||||
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = 'auth_provider.homeassistant'
|
||||
|
||||
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
|
||||
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||
}, extra=vol.PREVENT_EXTRA)
|
||||
|
||||
|
||||
|
@ -43,7 +46,7 @@ class Data:
|
|||
|
||||
if data is None:
|
||||
data = {
|
||||
'salt': auth.generate_secret(),
|
||||
'salt': generate_secret(),
|
||||
'users': []
|
||||
}
|
||||
|
||||
|
@ -112,8 +115,8 @@ class Data:
|
|||
await self._store.async_save(self._data)
|
||||
|
||||
|
||||
@auth.AUTH_PROVIDERS.register('homeassistant')
|
||||
class HassAuthProvider(auth.AuthProvider):
|
||||
@AUTH_PROVIDERS.register('homeassistant')
|
||||
class HassAuthProvider(AuthProvider):
|
||||
"""Auth provider based on a local storage of users in HASS config dir."""
|
||||
|
||||
DEFAULT_TITLE = 'Home Assistant Local'
|
|
@ -5,9 +5,11 @@ import hmac
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant import auth, data_entry_flow
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import callback
|
||||
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
|
||||
|
||||
USER_SCHEMA = vol.Schema({
|
||||
vol.Required('username'): str,
|
||||
|
@ -16,7 +18,7 @@ USER_SCHEMA = vol.Schema({
|
|||
})
|
||||
|
||||
|
||||
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
|
||||
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||
vol.Required('users'): [USER_SCHEMA]
|
||||
}, extra=vol.PREVENT_EXTRA)
|
||||
|
||||
|
@ -25,8 +27,8 @@ class InvalidAuthError(HomeAssistantError):
|
|||
"""Raised when submitting invalid authentication."""
|
||||
|
||||
|
||||
@auth.AUTH_PROVIDERS.register('insecure_example')
|
||||
class ExampleAuthProvider(auth.AuthProvider):
|
||||
@AUTH_PROVIDERS.register('insecure_example')
|
||||
class ExampleAuthProvider(AuthProvider):
|
||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||
|
||||
async def async_credential_flow(self):
|
|
@ -9,15 +9,18 @@ import hmac
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant import auth, data_entry_flow
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import callback
|
||||
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
|
||||
|
||||
USER_SCHEMA = vol.Schema({
|
||||
vol.Required('username'): str,
|
||||
})
|
||||
|
||||
|
||||
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
|
||||
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||
}, extra=vol.PREVENT_EXTRA)
|
||||
|
||||
LEGACY_USER = 'homeassistant'
|
||||
|
@ -27,8 +30,8 @@ class InvalidAuthError(HomeAssistantError):
|
|||
"""Raised when submitting invalid authentication."""
|
||||
|
||||
|
||||
@auth.AUTH_PROVIDERS.register('legacy_api_password')
|
||||
class LegacyApiPasswordAuthProvider(auth.AuthProvider):
|
||||
@AUTH_PROVIDERS.register('legacy_api_password')
|
||||
class LegacyApiPasswordAuthProvider(AuthProvider):
|
||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||
|
||||
DEFAULT_TITLE = 'Legacy API Password'
|
|
@ -0,0 +1,13 @@
|
|||
"""Auth utils."""
|
||||
import binascii
|
||||
import os
|
||||
|
||||
|
||||
def generate_secret(entropy: int = 32) -> str:
|
||||
"""Generate a secret.
|
||||
|
||||
Backport of secrets.token_hex from Python 3.6
|
||||
|
||||
Event loop friendly.
|
||||
"""
|
||||
return binascii.hexlify(os.urandom(entropy)).decode('ascii')
|
|
@ -1 +0,0 @@
|
|||
"""Auth providers for Home Assistant."""
|
|
@ -10,7 +10,7 @@ import logging
|
|||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth import generate_secret
|
||||
from homeassistant.auth.util import generate_secret
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.const import CONF_API_KEY, EVENT_HOMEASSISTANT_STOP, URL_API
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
|
|
@ -13,6 +13,7 @@ import voluptuous as vol
|
|||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant import auth
|
||||
from homeassistant.auth import providers as auth_providers
|
||||
from homeassistant.const import (
|
||||
ATTR_FRIENDLY_NAME, ATTR_HIDDEN, ATTR_ASSUMED_STATE,
|
||||
CONF_LATITUDE, CONF_LONGITUDE, CONF_NAME, CONF_PACKAGES, CONF_UNIT_SYSTEM,
|
||||
|
@ -159,7 +160,7 @@ CORE_CONFIG_SCHEMA = CUSTOMIZE_CONFIG_SCHEMA.extend({
|
|||
vol.All(cv.ensure_list, [vol.IsDir()]),
|
||||
vol.Optional(CONF_PACKAGES, default={}): PACKAGES_CONFIG_SCHEMA,
|
||||
vol.Optional(CONF_AUTH_PROVIDERS):
|
||||
vol.All(cv.ensure_list, [auth.AUTH_PROVIDER_SCHEMA])
|
||||
vol.All(cv.ensure_list, [auth_providers.AUTH_PROVIDER_SCHEMA])
|
||||
})
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import os
|
|||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.config import get_default_config_dir
|
||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||
from homeassistant.auth.providers import homeassistant as hass_auth
|
||||
|
||||
|
||||
def run(args):
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
"""Tests for the Home Assistant auth module."""
|
|
@ -2,7 +2,7 @@
|
|||
import pytest
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||
from homeassistant.auth.providers import homeassistant as hass_auth
|
||||
|
||||
|
||||
@pytest.fixture
|
|
@ -4,8 +4,8 @@ import uuid
|
|||
|
||||
import pytest
|
||||
|
||||
from homeassistant import auth
|
||||
from homeassistant.auth_providers import insecure_example
|
||||
from homeassistant.auth import auth_store, models as auth_models
|
||||
from homeassistant.auth.providers import insecure_example
|
||||
|
||||
from tests.common import mock_coro
|
||||
|
||||
|
@ -13,7 +13,7 @@ from tests.common import mock_coro
|
|||
@pytest.fixture
|
||||
def store(hass):
|
||||
"""Mock store."""
|
||||
return auth.AuthStore(hass)
|
||||
return auth_store.AuthStore(hass)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -45,7 +45,7 @@ async def test_create_new_credential(provider):
|
|||
|
||||
async def test_match_existing_credentials(store, provider):
|
||||
"""See if we match existing users."""
|
||||
existing = auth.Credentials(
|
||||
existing = auth_models.Credentials(
|
||||
id=uuid.uuid4(),
|
||||
auth_provider_type='insecure_example',
|
||||
auth_provider_id=None,
|
|
@ -4,13 +4,14 @@ from unittest.mock import Mock
|
|||
import pytest
|
||||
|
||||
from homeassistant import auth
|
||||
from homeassistant.auth_providers import legacy_api_password
|
||||
from homeassistant.auth import auth_store
|
||||
from homeassistant.auth.providers import legacy_api_password
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(hass):
|
||||
"""Mock store."""
|
||||
return auth.AuthStore(hass)
|
||||
return auth_store.AuthStore(hass)
|
||||
|
||||
|
||||
@pytest.fixture
|
|
@ -5,6 +5,8 @@ from unittest.mock import Mock, patch
|
|||
import pytest
|
||||
|
||||
from homeassistant import auth, data_entry_flow
|
||||
from homeassistant.auth import (
|
||||
models as auth_models, auth_store, const as auth_const)
|
||||
from homeassistant.util import dt as dt_util
|
||||
from tests.common import (
|
||||
MockUser, ensure_auth_manager_loaded, flush_store, CLIENT_ID)
|
||||
|
@ -101,7 +103,7 @@ async def test_login_as_existing_user(mock_hass):
|
|||
is_active=False,
|
||||
name='Not user',
|
||||
).add_to_auth_manager(manager)
|
||||
user.credentials.append(auth.Credentials(
|
||||
user.credentials.append(auth_models.Credentials(
|
||||
id='mock-id2',
|
||||
auth_provider_type='insecure_example',
|
||||
auth_provider_id=None,
|
||||
|
@ -116,7 +118,7 @@ async def test_login_as_existing_user(mock_hass):
|
|||
is_active=False,
|
||||
name='Paulus',
|
||||
).add_to_auth_manager(manager)
|
||||
user.credentials.append(auth.Credentials(
|
||||
user.credentials.append(auth_models.Credentials(
|
||||
id='mock-id',
|
||||
auth_provider_type='insecure_example',
|
||||
auth_provider_id=None,
|
||||
|
@ -203,7 +205,7 @@ async def test_saving_loading(hass, hass_storage):
|
|||
|
||||
await flush_store(manager._store._store)
|
||||
|
||||
store2 = auth.AuthStore(hass)
|
||||
store2 = auth_store.AuthStore(hass)
|
||||
users = await store2.async_get_users()
|
||||
assert len(users) == 1
|
||||
assert users[0] == user
|
||||
|
@ -211,23 +213,25 @@ async def test_saving_loading(hass, hass_storage):
|
|||
|
||||
def test_access_token_expired():
|
||||
"""Test that the expired property on access tokens work."""
|
||||
refresh_token = auth.RefreshToken(
|
||||
refresh_token = auth_models.RefreshToken(
|
||||
user=None,
|
||||
client_id='bla'
|
||||
)
|
||||
|
||||
access_token = auth.AccessToken(
|
||||
access_token = auth_models.AccessToken(
|
||||
refresh_token=refresh_token
|
||||
)
|
||||
|
||||
assert access_token.expired is False
|
||||
|
||||
with patch('homeassistant.auth.dt_util.utcnow',
|
||||
return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION):
|
||||
with patch('homeassistant.util.dt.utcnow',
|
||||
return_value=dt_util.utcnow() +
|
||||
auth_const.ACCESS_TOKEN_EXPIRATION):
|
||||
assert access_token.expired is True
|
||||
|
||||
almost_exp = dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION - timedelta(1)
|
||||
with patch('homeassistant.auth.dt_util.utcnow', return_value=almost_exp):
|
||||
almost_exp = \
|
||||
dt_util.utcnow() + auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(1)
|
||||
with patch('homeassistant.util.dt.utcnow', return_value=almost_exp):
|
||||
assert access_token.expired is False
|
||||
|
||||
|
||||
|
@ -242,8 +246,9 @@ async def test_cannot_retrieve_expired_access_token(hass):
|
|||
access_token = manager.async_create_access_token(refresh_token)
|
||||
assert manager.async_get_access_token(access_token.token) is access_token
|
||||
|
||||
with patch('homeassistant.auth.dt_util.utcnow',
|
||||
return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION):
|
||||
with patch('homeassistant.util.dt.utcnow',
|
||||
return_value=dt_util.utcnow() +
|
||||
auth_const.ACCESS_TOKEN_EXPIRATION):
|
||||
assert manager.async_get_access_token(access_token.token) is None
|
||||
|
||||
# Even with unpatched time, it should have been removed from manager
|
|
@ -12,6 +12,7 @@ import threading
|
|||
from contextlib import contextmanager
|
||||
|
||||
from homeassistant import auth, core as ha, data_entry_flow, config_entries
|
||||
from homeassistant.auth import models as auth_models, auth_store
|
||||
from homeassistant.setup import setup_component, async_setup_component
|
||||
from homeassistant.config import async_process_component_config
|
||||
from homeassistant.helpers import (
|
||||
|
@ -114,7 +115,7 @@ def async_test_home_assistant(loop):
|
|||
"""Return a Home Assistant object pointing at test config dir."""
|
||||
hass = ha.HomeAssistant(loop)
|
||||
hass.config.async_load = Mock()
|
||||
store = auth.AuthStore(hass)
|
||||
store = auth_store.AuthStore(hass)
|
||||
hass.auth = auth.AuthManager(hass, store, {})
|
||||
ensure_auth_manager_loaded(hass.auth)
|
||||
INSTANCES.append(hass)
|
||||
|
@ -308,7 +309,7 @@ def mock_registry(hass, mock_entries=None):
|
|||
return registry
|
||||
|
||||
|
||||
class MockUser(auth.User):
|
||||
class MockUser(auth_models.User):
|
||||
"""Mock a user in Home Assistant."""
|
||||
|
||||
def __init__(self, id='mock-id', is_owner=True, is_active=True,
|
||||
|
|
|
@ -7,7 +7,7 @@ import pytest
|
|||
from aiohttp import BasicAuth, web
|
||||
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||
|
||||
from homeassistant.auth import AccessToken, RefreshToken
|
||||
from homeassistant.auth.models import AccessToken, RefreshToken
|
||||
from homeassistant.components.http.auth import setup_auth
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||
from homeassistant.components.http.real_ip import setup_real_ip
|
||||
|
|
|
@ -4,7 +4,7 @@ from unittest.mock import Mock, patch
|
|||
import pytest
|
||||
|
||||
from homeassistant.scripts import auth as script_auth
|
||||
from homeassistant.auth_providers import homeassistant as hass_auth
|
||||
from homeassistant.auth.providers import homeassistant as hass_auth
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
Loading…
Reference in New Issue