core/homeassistant/auth/providers/__init__.py

248 lines
8.1 KiB
Python
Raw Normal View History

2018-07-13 09:43:08 +00:00
"""Auth providers for Home Assistant."""
import importlib
import logging
import types
from typing import Any, Dict, List, Optional
2018-07-13 09:43:08 +00:00
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant import data_entry_flow, requirements
from homeassistant.core import callback, HomeAssistant
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
from homeassistant.util import dt as dt_util
2018-07-13 09:43:08 +00:00
from homeassistant.util.decorator import Registry
from ..auth_store import AuthStore
from ..models import Credentials, User, UserMeta # noqa: F401
from ..mfa_modules import SESSION_EXPIRATION
2018-07-13 09:43:08 +00:00
_LOGGER = logging.getLogger(__name__)
DATA_REQS = 'auth_prov_reqs_processed'
AUTH_PROVIDERS = Registry()
AUTH_PROVIDER_SCHEMA = vol.Schema({
vol.Required(CONF_TYPE): str,
vol.Optional(CONF_NAME): str,
# Specify ID if you have two auth providers for same type.
vol.Optional(CONF_ID): str,
}, extra=vol.ALLOW_EXTRA)
class AuthProvider:
"""Provider of user authentication."""
DEFAULT_TITLE = 'Unnamed auth provider'
def __init__(self, hass: HomeAssistant, store: AuthStore,
config: Dict[str, Any]) -> None:
2018-07-13 09:43:08 +00:00
"""Initialize an auth provider."""
self.hass = hass
self.store = store
self.config = config
@property
def id(self) -> Optional[str]: # pylint: disable=invalid-name
2018-07-13 09:43:08 +00:00
"""Return id of the auth provider.
Optional, can be None.
"""
return self.config.get(CONF_ID)
@property
def type(self) -> str:
2018-07-13 09:43:08 +00:00
"""Return type of the provider."""
return self.config[CONF_TYPE] # type: ignore
2018-07-13 09:43:08 +00:00
@property
def name(self) -> str:
2018-07-13 09:43:08 +00:00
"""Return the name of the auth provider."""
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
@property
def support_mfa(self) -> bool:
"""Return whether multi-factor auth supported by the auth provider."""
return True
async def async_credentials(self) -> List[Credentials]:
2018-07-13 09:43:08 +00:00
"""Return all credentials of this provider."""
users = await self.store.async_get_users()
return [
credentials
for user in users
for credentials in user.credentials
if (credentials.auth_provider_type == self.type and
credentials.auth_provider_id == self.id)
]
@callback
def async_create_credentials(self, data: Dict[str, str]) -> Credentials:
2018-07-13 09:43:08 +00:00
"""Create credentials."""
return Credentials(
auth_provider_type=self.type,
auth_provider_id=self.id,
data=data,
)
# Implement by extending class
async def async_login_flow(self, context: Optional[Dict]) -> 'LoginFlow':
"""Return the data flow for logging in with auth provider.
Auth provider should extend LoginFlow and return an instance.
"""
2018-07-13 09:43:08 +00:00
raise NotImplementedError
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials:
2018-07-13 09:43:08 +00:00
"""Get credentials based on the flow result."""
raise NotImplementedError
async def async_user_meta_for_credentials(
self, credentials: Credentials) -> UserMeta:
2018-07-13 09:43:08 +00:00
"""Return extra user metadata for credentials.
Will be used to populate info when creating a new user.
"""
raise NotImplementedError
async def auth_provider_from_config(
hass: HomeAssistant, store: AuthStore,
config: Dict[str, Any]) -> Optional[AuthProvider]:
"""Initialize an auth provider from a config."""
provider_name = config[CONF_TYPE]
module = await load_auth_provider_module(hass, provider_name)
if module is None:
return None
try:
config = module.CONFIG_SCHEMA(config) # type: ignore
except vol.Invalid as err:
_LOGGER.error('Invalid configuration for auth provider %s: %s',
provider_name, humanize_error(config, err))
return None
return AUTH_PROVIDERS[provider_name](hass, store, config) # type: ignore
async def load_auth_provider_module(
hass: HomeAssistant, provider: str) -> Optional[types.ModuleType]:
"""Load an auth provider."""
try:
module = importlib.import_module(
'homeassistant.auth.providers.{}'.format(provider))
except ImportError:
_LOGGER.warning('Unable to find auth provider %s', provider)
return None
if hass.config.skip_pip or not hasattr(module, 'REQUIREMENTS'):
return module
processed = hass.data.get(DATA_REQS)
if processed is None:
processed = hass.data[DATA_REQS] = set()
elif provider in processed:
return module
# https://github.com/python/mypy/issues/1424
reqs = module.REQUIREMENTS # type: ignore
req_success = await requirements.async_process_requirements(
hass, 'auth provider {}'.format(provider), reqs)
if not req_success:
return None
processed.add(provider)
return module
class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow."""
def __init__(self, auth_provider: AuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
self._auth_module_id = None # type: Optional[str]
self._auth_manager = auth_provider.hass.auth # type: ignore
self.available_mfa_modules = [] # type: List
self.created_at = dt_util.utcnow()
self.user = None # type: Optional[User]
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]:
"""Handle the first step of login flow.
Return self.async_show_form(step_id='init') if user_input == None.
Return await self.async_finish(flow_result) if login init step pass.
"""
raise NotImplementedError
async def async_step_select_mfa_module(
self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]:
"""Handle the step of select mfa module."""
errors = {}
if user_input is not None:
auth_module = user_input.get('multi_factor_auth_module')
if auth_module in self.available_mfa_modules:
self._auth_module_id = auth_module
return await self.async_step_mfa()
errors['base'] = 'invalid_auth_module'
if len(self.available_mfa_modules) == 1:
self._auth_module_id = self.available_mfa_modules[0]
return await self.async_step_mfa()
return self.async_show_form(
step_id='select_mfa_module',
data_schema=vol.Schema({
'multi_factor_auth_module': vol.In(self.available_mfa_modules)
}),
errors=errors,
)
async def async_step_mfa(
self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]:
"""Handle the step of mfa validation."""
errors = {}
auth_module = self._auth_manager.get_auth_mfa_module(
self._auth_module_id)
if auth_module is None:
# Given an invalid input to async_step_select_mfa_module
# will show invalid_auth_module error
return await self.async_step_select_mfa_module(user_input={})
if user_input is not None:
expires = self.created_at + SESSION_EXPIRATION
if dt_util.utcnow() > expires:
errors['base'] = 'login_expired'
else:
result = await auth_module.async_validation(
self.user.id, user_input) # type: ignore
if not result:
errors['base'] = 'invalid_auth'
if not errors:
return await self.async_finish(self.user)
return self.async_show_form(
step_id='mfa',
data_schema=auth_module.input_schema,
errors=errors,
)
async def async_finish(self, flow_result: Any) -> Dict:
"""Handle the pass of login flow."""
return self.async_create_entry(
title=self._auth_provider.name,
data=flow_result
)