Add type hints to homeassistant.auth (#15853)
* Always load users in auth store before use * Use namedtuple instead of dict for user meta * Ignore auth store tokens with invalid created_at * Add type hints to homeassistant.authpull/16021/head
parent
e9e5bce10c
commit
649f17fe47
|
@ -2,7 +2,7 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import List, Awaitable
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import jwt
|
||||
|
||||
|
@ -10,15 +10,17 @@ from homeassistant import data_entry_flow
|
|||
from homeassistant.core import callback, HomeAssistant
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import auth_store
|
||||
from .providers import auth_provider_from_config
|
||||
from . import auth_store, models
|
||||
from .providers import auth_provider_from_config, AuthProvider
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_ProviderKey = Tuple[str, Optional[str]]
|
||||
_ProviderDict = Dict[_ProviderKey, AuthProvider]
|
||||
|
||||
|
||||
async def auth_manager_from_config(
|
||||
hass: HomeAssistant,
|
||||
provider_configs: List[dict]) -> Awaitable['AuthManager']:
|
||||
provider_configs: List[Dict[str, Any]]) -> 'AuthManager':
|
||||
"""Initialize an auth manager from config."""
|
||||
store = auth_store.AuthStore(hass)
|
||||
if provider_configs:
|
||||
|
@ -26,9 +28,9 @@ async def auth_manager_from_config(
|
|||
*[auth_provider_from_config(hass, store, config)
|
||||
for config in provider_configs])
|
||||
else:
|
||||
providers = []
|
||||
providers = ()
|
||||
# So returned auth providers are in same order as config
|
||||
provider_hash = OrderedDict()
|
||||
provider_hash = OrderedDict() # type: _ProviderDict
|
||||
for provider in providers:
|
||||
if provider is None:
|
||||
continue
|
||||
|
@ -49,7 +51,8 @@ async def auth_manager_from_config(
|
|||
class AuthManager:
|
||||
"""Manage the authentication for Home Assistant."""
|
||||
|
||||
def __init__(self, hass, store, providers):
|
||||
def __init__(self, hass: HomeAssistant, store: auth_store.AuthStore,
|
||||
providers: _ProviderDict) -> None:
|
||||
"""Initialize the auth manager."""
|
||||
self._store = store
|
||||
self._providers = providers
|
||||
|
@ -58,12 +61,12 @@ class AuthManager:
|
|||
self._async_finish_login_flow)
|
||||
|
||||
@property
|
||||
def active(self):
|
||||
def active(self) -> bool:
|
||||
"""Return if any auth providers are registered."""
|
||||
return bool(self._providers)
|
||||
|
||||
@property
|
||||
def support_legacy(self):
|
||||
def support_legacy(self) -> bool:
|
||||
"""
|
||||
Return if legacy_api_password auth providers are registered.
|
||||
|
||||
|
@ -75,19 +78,19 @@ class AuthManager:
|
|||
return False
|
||||
|
||||
@property
|
||||
def auth_providers(self):
|
||||
def auth_providers(self) -> List[AuthProvider]:
|
||||
"""Return a list of available auth providers."""
|
||||
return list(self._providers.values())
|
||||
|
||||
async def async_get_users(self):
|
||||
async def async_get_users(self) -> List[models.User]:
|
||||
"""Retrieve all users."""
|
||||
return await self._store.async_get_users()
|
||||
|
||||
async def async_get_user(self, user_id):
|
||||
async def async_get_user(self, user_id: str) -> Optional[models.User]:
|
||||
"""Retrieve a user."""
|
||||
return await self._store.async_get_user(user_id)
|
||||
|
||||
async def async_create_system_user(self, name):
|
||||
async def async_create_system_user(self, name: str) -> models.User:
|
||||
"""Create a system user."""
|
||||
return await self._store.async_create_user(
|
||||
name=name,
|
||||
|
@ -95,19 +98,20 @@ class AuthManager:
|
|||
is_active=True,
|
||||
)
|
||||
|
||||
async def async_create_user(self, name):
|
||||
async def async_create_user(self, name: str) -> models.User:
|
||||
"""Create a user."""
|
||||
kwargs = {
|
||||
'name': name,
|
||||
'is_active': True,
|
||||
}
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
if await self._user_should_be_owner():
|
||||
kwargs['is_owner'] = True
|
||||
|
||||
return await self._store.async_create_user(**kwargs)
|
||||
|
||||
async def async_get_or_create_user(self, credentials):
|
||||
async def async_get_or_create_user(self, credentials: models.Credentials) \
|
||||
-> models.User:
|
||||
"""Get or create a user."""
|
||||
if not credentials.is_new:
|
||||
for user in await self._store.async_get_users():
|
||||
|
@ -127,15 +131,16 @@ class AuthManager:
|
|||
|
||||
return await self._store.async_create_user(
|
||||
credentials=credentials,
|
||||
name=info.get('name'),
|
||||
is_active=info.get('is_active', False)
|
||||
name=info.name,
|
||||
is_active=info.is_active,
|
||||
)
|
||||
|
||||
async def async_link_user(self, user, credentials):
|
||||
async def async_link_user(self, user: models.User,
|
||||
credentials: models.Credentials) -> None:
|
||||
"""Link credentials to an existing user."""
|
||||
await self._store.async_link_user(user, credentials)
|
||||
|
||||
async def async_remove_user(self, user):
|
||||
async def async_remove_user(self, user: models.User) -> None:
|
||||
"""Remove a user."""
|
||||
tasks = [
|
||||
self.async_remove_credentials(credentials)
|
||||
|
@ -147,27 +152,32 @@ class AuthManager:
|
|||
|
||||
await self._store.async_remove_user(user)
|
||||
|
||||
async def async_activate_user(self, user):
|
||||
async def async_activate_user(self, user: models.User) -> None:
|
||||
"""Activate a user."""
|
||||
await self._store.async_activate_user(user)
|
||||
|
||||
async def async_deactivate_user(self, user):
|
||||
async def async_deactivate_user(self, user: models.User) -> None:
|
||||
"""Deactivate a user."""
|
||||
if user.is_owner:
|
||||
raise ValueError('Unable to deactive the owner')
|
||||
await self._store.async_deactivate_user(user)
|
||||
|
||||
async def async_remove_credentials(self, credentials):
|
||||
async def async_remove_credentials(
|
||||
self, credentials: models.Credentials) -> None:
|
||||
"""Remove credentials."""
|
||||
provider = self._async_get_auth_provider(credentials)
|
||||
|
||||
if (provider is not None and
|
||||
hasattr(provider, 'async_will_remove_credentials')):
|
||||
await provider.async_will_remove_credentials(credentials)
|
||||
# https://github.com/python/mypy/issues/1424
|
||||
await provider.async_will_remove_credentials( # type: ignore
|
||||
credentials)
|
||||
|
||||
await self._store.async_remove_credentials(credentials)
|
||||
|
||||
async def async_create_refresh_token(self, user, client_id=None):
|
||||
async def async_create_refresh_token(self, user: models.User,
|
||||
client_id: Optional[str] = None) \
|
||||
-> models.RefreshToken:
|
||||
"""Create a new refresh token for a user."""
|
||||
if not user.is_active:
|
||||
raise ValueError('User is not active')
|
||||
|
@ -182,16 +192,19 @@ class AuthManager:
|
|||
|
||||
return await self._store.async_create_refresh_token(user, client_id)
|
||||
|
||||
async def async_get_refresh_token(self, token_id):
|
||||
async def async_get_refresh_token(
|
||||
self, token_id: str) -> Optional[models.RefreshToken]:
|
||||
"""Get refresh token by id."""
|
||||
return await self._store.async_get_refresh_token(token_id)
|
||||
|
||||
async def async_get_refresh_token_by_token(self, token):
|
||||
async def async_get_refresh_token_by_token(
|
||||
self, token: str) -> Optional[models.RefreshToken]:
|
||||
"""Get refresh token by token."""
|
||||
return await self._store.async_get_refresh_token_by_token(token)
|
||||
|
||||
@callback
|
||||
def async_create_access_token(self, refresh_token):
|
||||
def async_create_access_token(self,
|
||||
refresh_token: models.RefreshToken) -> str:
|
||||
"""Create a new access token."""
|
||||
# pylint: disable=no-self-use
|
||||
return jwt.encode({
|
||||
|
@ -200,7 +213,8 @@ class AuthManager:
|
|||
'exp': dt_util.utcnow() + refresh_token.access_token_expiration,
|
||||
}, refresh_token.jwt_key, algorithm='HS256').decode()
|
||||
|
||||
async def async_validate_access_token(self, token):
|
||||
async def async_validate_access_token(
|
||||
self, token: str) -> Optional[models.RefreshToken]:
|
||||
"""Return if an access token is valid."""
|
||||
try:
|
||||
unverif_claims = jwt.decode(token, verify=False)
|
||||
|
@ -208,7 +222,7 @@ class AuthManager:
|
|||
return None
|
||||
|
||||
refresh_token = await self.async_get_refresh_token(
|
||||
unverif_claims.get('iss'))
|
||||
cast(str, unverif_claims.get('iss')))
|
||||
|
||||
if refresh_token is None:
|
||||
jwt_key = ''
|
||||
|
@ -228,18 +242,22 @@ class AuthManager:
|
|||
except jwt.InvalidTokenError:
|
||||
return None
|
||||
|
||||
if not refresh_token.user.is_active:
|
||||
if refresh_token is None or not refresh_token.user.is_active:
|
||||
return None
|
||||
|
||||
return refresh_token
|
||||
|
||||
async def _async_create_login_flow(self, handler, *, context, data):
|
||||
async def _async_create_login_flow(
|
||||
self, handler: _ProviderKey, *, context: Optional[Dict],
|
||||
data: Optional[Any]) -> data_entry_flow.FlowHandler:
|
||||
"""Create a login flow."""
|
||||
auth_provider = self._providers[handler]
|
||||
|
||||
return await auth_provider.async_credential_flow(context)
|
||||
|
||||
async def _async_finish_login_flow(self, context, result):
|
||||
async def _async_finish_login_flow(
|
||||
self, context: Optional[Dict], result: Dict[str, Any]) \
|
||||
-> Optional[models.Credentials]:
|
||||
"""Result of a credential login flow."""
|
||||
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return None
|
||||
|
@ -249,13 +267,14 @@ class AuthManager:
|
|||
result['data'])
|
||||
|
||||
@callback
|
||||
def _async_get_auth_provider(self, credentials):
|
||||
def _async_get_auth_provider(
|
||||
self, credentials: models.Credentials) -> Optional[AuthProvider]:
|
||||
"""Helper to get auth provider from a set of credentials."""
|
||||
auth_provider_key = (credentials.auth_provider_type,
|
||||
credentials.auth_provider_id)
|
||||
return self._providers.get(auth_provider_key)
|
||||
|
||||
async def _user_should_be_owner(self):
|
||||
async def _user_should_be_owner(self) -> bool:
|
||||
"""Determine if user should be owner.
|
||||
|
||||
A user should be an owner if it is the first non-system user that is
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
"""Storage for auth models."""
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
from logging import getLogger
|
||||
from typing import Any, Dict, List, Optional # noqa: F401
|
||||
import hmac
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import models
|
||||
|
@ -20,35 +23,41 @@ class AuthStore:
|
|||
called that needs it.
|
||||
"""
|
||||
|
||||
def __init__(self, hass):
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the auth store."""
|
||||
self.hass = hass
|
||||
self._users = None
|
||||
self._users = None # type: Optional[Dict[str, models.User]]
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
async def async_get_users(self):
|
||||
async def async_get_users(self) -> List[models.User]:
|
||||
"""Retrieve all users."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
assert self._users is not None
|
||||
|
||||
return list(self._users.values())
|
||||
|
||||
async def async_get_user(self, user_id):
|
||||
async def async_get_user(self, user_id: str) -> Optional[models.User]:
|
||||
"""Retrieve a user by id."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
assert self._users is not None
|
||||
|
||||
return self._users.get(user_id)
|
||||
|
||||
async def async_create_user(self, name, is_owner=None, is_active=None,
|
||||
system_generated=None, credentials=None):
|
||||
async def async_create_user(
|
||||
self, name: Optional[str], is_owner: Optional[bool] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
system_generated: Optional[bool] = None,
|
||||
credentials: Optional[models.Credentials] = None) -> models.User:
|
||||
"""Create a new user."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
assert self._users is not None
|
||||
|
||||
kwargs = {
|
||||
'name': name
|
||||
}
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
if is_owner is not None:
|
||||
kwargs['is_owner'] = is_owner
|
||||
|
@ -71,29 +80,39 @@ class AuthStore:
|
|||
await self.async_link_user(new_user, credentials)
|
||||
return new_user
|
||||
|
||||
async def async_link_user(self, user, credentials):
|
||||
async def async_link_user(self, user: models.User,
|
||||
credentials: models.Credentials) -> None:
|
||||
"""Add credentials to an existing user."""
|
||||
user.credentials.append(credentials)
|
||||
await self.async_save()
|
||||
credentials.is_new = False
|
||||
|
||||
async def async_remove_user(self, user):
|
||||
async def async_remove_user(self, user: models.User) -> None:
|
||||
"""Remove a user."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
assert self._users is not None
|
||||
|
||||
self._users.pop(user.id)
|
||||
await self.async_save()
|
||||
|
||||
async def async_activate_user(self, user):
|
||||
async def async_activate_user(self, user: models.User) -> None:
|
||||
"""Activate a user."""
|
||||
user.is_active = True
|
||||
await self.async_save()
|
||||
|
||||
async def async_deactivate_user(self, user):
|
||||
async def async_deactivate_user(self, user: models.User) -> None:
|
||||
"""Activate a user."""
|
||||
user.is_active = False
|
||||
await self.async_save()
|
||||
|
||||
async def async_remove_credentials(self, credentials):
|
||||
async def async_remove_credentials(
|
||||
self, credentials: models.Credentials) -> None:
|
||||
"""Remove credentials."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
assert self._users is not None
|
||||
|
||||
for user in self._users.values():
|
||||
found = None
|
||||
|
||||
|
@ -108,17 +127,21 @@ class AuthStore:
|
|||
|
||||
await self.async_save()
|
||||
|
||||
async def async_create_refresh_token(self, user, client_id=None):
|
||||
async def async_create_refresh_token(
|
||||
self, user: models.User, client_id: Optional[str] = None) \
|
||||
-> models.RefreshToken:
|
||||
"""Create a new token for a user."""
|
||||
refresh_token = models.RefreshToken(user=user, client_id=client_id)
|
||||
user.refresh_tokens[refresh_token.id] = refresh_token
|
||||
await self.async_save()
|
||||
return refresh_token
|
||||
|
||||
async def async_get_refresh_token(self, token_id):
|
||||
async def async_get_refresh_token(
|
||||
self, token_id: str) -> Optional[models.RefreshToken]:
|
||||
"""Get refresh token by id."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
assert self._users is not None
|
||||
|
||||
for user in self._users.values():
|
||||
refresh_token = user.refresh_tokens.get(token_id)
|
||||
|
@ -127,10 +150,12 @@ class AuthStore:
|
|||
|
||||
return None
|
||||
|
||||
async def async_get_refresh_token_by_token(self, token):
|
||||
async def async_get_refresh_token_by_token(
|
||||
self, token: str) -> Optional[models.RefreshToken]:
|
||||
"""Get refresh token by token."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
assert self._users is not None
|
||||
|
||||
found = None
|
||||
|
||||
|
@ -141,7 +166,7 @@ class AuthStore:
|
|||
|
||||
return found
|
||||
|
||||
async def async_load(self):
|
||||
async def async_load(self) -> None:
|
||||
"""Load the users."""
|
||||
data = await self._store.async_load()
|
||||
|
||||
|
@ -150,7 +175,7 @@ class AuthStore:
|
|||
if self._users is not None:
|
||||
return
|
||||
|
||||
users = OrderedDict()
|
||||
users = OrderedDict() # type: Dict[str, models.User]
|
||||
|
||||
if data is None:
|
||||
self._users = users
|
||||
|
@ -173,11 +198,17 @@ class AuthStore:
|
|||
if 'jwt_key' not in rt_dict:
|
||||
continue
|
||||
|
||||
created_at = dt_util.parse_datetime(rt_dict['created_at'])
|
||||
if created_at is None:
|
||||
getLogger(__name__).error(
|
||||
'Ignoring refresh token %(id)s with invalid created_at '
|
||||
'%(created_at)s for user_id %(user_id)s', rt_dict)
|
||||
continue
|
||||
token = models.RefreshToken(
|
||||
id=rt_dict['id'],
|
||||
user=users[rt_dict['user_id']],
|
||||
client_id=rt_dict['client_id'],
|
||||
created_at=dt_util.parse_datetime(rt_dict['created_at']),
|
||||
created_at=created_at,
|
||||
access_token_expiration=timedelta(
|
||||
seconds=rt_dict['access_token_expiration']),
|
||||
token=rt_dict['token'],
|
||||
|
@ -187,8 +218,12 @@ class AuthStore:
|
|||
|
||||
self._users = users
|
||||
|
||||
async def async_save(self):
|
||||
async def async_save(self) -> None:
|
||||
"""Save users."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
assert self._users is not None
|
||||
|
||||
users = [
|
||||
{
|
||||
'id': user.id,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Auth models."""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, NamedTuple, Optional # noqa: F401
|
||||
import uuid
|
||||
|
||||
import attr
|
||||
|
@ -14,17 +15,21 @@ from .util import generate_secret
|
|||
class User:
|
||||
"""A user."""
|
||||
|
||||
name = attr.ib(type=str)
|
||||
name = attr.ib(type=str) # type: Optional[str]
|
||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||
is_owner = attr.ib(type=bool, default=False)
|
||||
is_active = attr.ib(type=bool, default=False)
|
||||
system_generated = attr.ib(type=bool, default=False)
|
||||
|
||||
# List of credentials of a user.
|
||||
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
||||
credentials = attr.ib(
|
||||
type=list, default=attr.Factory(list), cmp=False
|
||||
) # type: List[Credentials]
|
||||
|
||||
# Tokens associated with a user.
|
||||
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)
|
||||
refresh_tokens = attr.ib(
|
||||
type=dict, default=attr.Factory(dict), cmp=False
|
||||
) # type: Dict[str, RefreshToken]
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
|
@ -32,7 +37,7 @@ class RefreshToken:
|
|||
"""RefreshToken for a user to grant new access tokens."""
|
||||
|
||||
user = attr.ib(type=User)
|
||||
client_id = attr.ib(type=str)
|
||||
client_id = attr.ib(type=str) # type: Optional[str]
|
||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
|
||||
access_token_expiration = attr.ib(type=timedelta,
|
||||
|
@ -48,10 +53,14 @@ class Credentials:
|
|||
"""Credentials for a user on an auth provider."""
|
||||
|
||||
auth_provider_type = attr.ib(type=str)
|
||||
auth_provider_id = attr.ib(type=str)
|
||||
auth_provider_id = attr.ib(type=str) # type: Optional[str]
|
||||
|
||||
# Allow the auth provider to store data to represent their auth.
|
||||
data = attr.ib(type=dict)
|
||||
|
||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||
is_new = attr.ib(type=bool, default=True)
|
||||
|
||||
|
||||
UserMeta = NamedTuple("UserMeta",
|
||||
[('name', Optional[str]), ('is_active', bool)])
|
||||
|
|
|
@ -1,16 +1,19 @@
|
|||
"""Auth providers for Home Assistant."""
|
||||
import importlib
|
||||
import logging
|
||||
import types
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant import requirements
|
||||
from homeassistant.core import callback
|
||||
from homeassistant import data_entry_flow, requirements
|
||||
from homeassistant.core import callback, HomeAssistant
|
||||
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
from homeassistant.auth.models import Credentials
|
||||
from homeassistant.auth.auth_store import AuthStore
|
||||
from homeassistant.auth.models import Credentials, UserMeta
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
DATA_REQS = 'auth_prov_reqs_processed'
|
||||
|
@ -25,7 +28,80 @@ AUTH_PROVIDER_SCHEMA = vol.Schema({
|
|||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
||||
async def auth_provider_from_config(hass, store, config):
|
||||
class AuthProvider:
|
||||
"""Provider of user authentication."""
|
||||
|
||||
DEFAULT_TITLE = 'Unnamed auth provider'
|
||||
|
||||
def __init__(self, hass: HomeAssistant, store: AuthStore,
|
||||
config: Dict[str, Any]) -> None:
|
||||
"""Initialize an auth provider."""
|
||||
self.hass = hass
|
||||
self.store = store
|
||||
self.config = config
|
||||
|
||||
@property
|
||||
def id(self) -> Optional[str]: # pylint: disable=invalid-name
|
||||
"""Return id of the auth provider.
|
||||
|
||||
Optional, can be None.
|
||||
"""
|
||||
return self.config.get(CONF_ID)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Return type of the provider."""
|
||||
return self.config[CONF_TYPE] # type: ignore
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the name of the auth provider."""
|
||||
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
|
||||
|
||||
async def async_credentials(self) -> List[Credentials]:
|
||||
"""Return all credentials of this provider."""
|
||||
users = await self.store.async_get_users()
|
||||
return [
|
||||
credentials
|
||||
for user in users
|
||||
for credentials in user.credentials
|
||||
if (credentials.auth_provider_type == self.type and
|
||||
credentials.auth_provider_id == self.id)
|
||||
]
|
||||
|
||||
@callback
|
||||
def async_create_credentials(self, data: Dict[str, str]) -> Credentials:
|
||||
"""Create credentials."""
|
||||
return Credentials(
|
||||
auth_provider_type=self.type,
|
||||
auth_provider_id=self.id,
|
||||
data=data,
|
||||
)
|
||||
|
||||
# Implement by extending class
|
||||
|
||||
async def async_credential_flow(
|
||||
self, context: Optional[Dict]) -> data_entry_flow.FlowHandler:
|
||||
"""Return the data flow for logging in with auth provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_user_meta_for_credentials(
|
||||
self, credentials: Credentials) -> UserMeta:
|
||||
"""Return extra user metadata for credentials.
|
||||
|
||||
Will be used to populate info when creating a new user.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
async def auth_provider_from_config(
|
||||
hass: HomeAssistant, store: AuthStore,
|
||||
config: Dict[str, Any]) -> Optional[AuthProvider]:
|
||||
"""Initialize an auth provider from a config."""
|
||||
provider_name = config[CONF_TYPE]
|
||||
module = await load_auth_provider_module(hass, provider_name)
|
||||
|
@ -34,16 +110,17 @@ async def auth_provider_from_config(hass, store, config):
|
|||
return None
|
||||
|
||||
try:
|
||||
config = module.CONFIG_SCHEMA(config)
|
||||
config = module.CONFIG_SCHEMA(config) # type: ignore
|
||||
except vol.Invalid as err:
|
||||
_LOGGER.error('Invalid configuration for auth provider %s: %s',
|
||||
provider_name, humanize_error(config, err))
|
||||
return None
|
||||
|
||||
return AUTH_PROVIDERS[provider_name](hass, store, config)
|
||||
return AUTH_PROVIDERS[provider_name](hass, store, config) # type: ignore
|
||||
|
||||
|
||||
async def load_auth_provider_module(hass, provider):
|
||||
async def load_auth_provider_module(
|
||||
hass: HomeAssistant, provider: str) -> Optional[types.ModuleType]:
|
||||
"""Load an auth provider."""
|
||||
try:
|
||||
module = importlib.import_module(
|
||||
|
@ -62,82 +139,13 @@ async def load_auth_provider_module(hass, provider):
|
|||
elif provider in processed:
|
||||
return module
|
||||
|
||||
# https://github.com/python/mypy/issues/1424
|
||||
reqs = module.REQUIREMENTS # type: ignore
|
||||
req_success = await requirements.async_process_requirements(
|
||||
hass, 'auth provider {}'.format(provider), module.REQUIREMENTS)
|
||||
hass, 'auth provider {}'.format(provider), reqs)
|
||||
|
||||
if not req_success:
|
||||
return None
|
||||
|
||||
processed.add(provider)
|
||||
return module
|
||||
|
||||
|
||||
class AuthProvider:
|
||||
"""Provider of user authentication."""
|
||||
|
||||
DEFAULT_TITLE = 'Unnamed auth provider'
|
||||
|
||||
def __init__(self, hass, store, config):
|
||||
"""Initialize an auth provider."""
|
||||
self.hass = hass
|
||||
self.store = store
|
||||
self.config = config
|
||||
|
||||
@property
|
||||
def id(self): # pylint: disable=invalid-name
|
||||
"""Return id of the auth provider.
|
||||
|
||||
Optional, can be None.
|
||||
"""
|
||||
return self.config.get(CONF_ID)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""Return type of the provider."""
|
||||
return self.config[CONF_TYPE]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the auth provider."""
|
||||
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
|
||||
|
||||
async def async_credentials(self):
|
||||
"""Return all credentials of this provider."""
|
||||
users = await self.store.async_get_users()
|
||||
return [
|
||||
credentials
|
||||
for user in users
|
||||
for credentials in user.credentials
|
||||
if (credentials.auth_provider_type == self.type and
|
||||
credentials.auth_provider_id == self.id)
|
||||
]
|
||||
|
||||
@callback
|
||||
def async_create_credentials(self, data):
|
||||
"""Create credentials."""
|
||||
return Credentials(
|
||||
auth_provider_type=self.type,
|
||||
auth_provider_id=self.id,
|
||||
data=data,
|
||||
)
|
||||
|
||||
# Implement by extending class
|
||||
|
||||
async def async_credential_flow(self, context):
|
||||
"""Return the data flow for logging in with auth provider."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_get_or_create_credentials(self, flow_result):
|
||||
"""Get credentials based on the flow result."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_user_meta_for_credentials(self, credentials):
|
||||
"""Return extra user metadata for credentials.
|
||||
|
||||
Will be used to populate info when creating a new user.
|
||||
|
||||
Values to populate:
|
||||
- name: string
|
||||
- is_active: boolean
|
||||
"""
|
||||
return {}
|
||||
|
|
|
@ -3,24 +3,25 @@ import base64
|
|||
from collections import OrderedDict
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Dict # noqa: F401 pylint: disable=unused-import
|
||||
from typing import Any, Dict, List, Optional # noqa: F401,E501 pylint: disable=unused-import
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.const import CONF_ID
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import callback, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from homeassistant.auth.util import generate_secret
|
||||
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
from ..models import Credentials, UserMeta
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = 'auth_provider.homeassistant'
|
||||
|
||||
|
||||
def _disallow_id(conf):
|
||||
def _disallow_id(conf: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Disallow ID in config."""
|
||||
if CONF_ID in conf:
|
||||
raise vol.Invalid(
|
||||
|
@ -46,13 +47,13 @@ class InvalidUser(HomeAssistantError):
|
|||
class Data:
|
||||
"""Hold the user data."""
|
||||
|
||||
def __init__(self, hass):
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the user data store."""
|
||||
self.hass = hass
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
self._data = None
|
||||
self._data = None # type: Optional[Dict[str, Any]]
|
||||
|
||||
async def async_load(self):
|
||||
async def async_load(self) -> None:
|
||||
"""Load stored data."""
|
||||
data = await self._store.async_load()
|
||||
|
||||
|
@ -65,9 +66,9 @@ class Data:
|
|||
self._data = data
|
||||
|
||||
@property
|
||||
def users(self):
|
||||
def users(self) -> List[Dict[str, str]]:
|
||||
"""Return users."""
|
||||
return self._data['users']
|
||||
return self._data['users'] # type: ignore
|
||||
|
||||
def validate_login(self, username: str, password: str) -> None:
|
||||
"""Validate a username and password.
|
||||
|
@ -79,7 +80,7 @@ class Data:
|
|||
found = None
|
||||
|
||||
# Compare all users to avoid timing attacks.
|
||||
for user in self._data['users']:
|
||||
for user in self.users:
|
||||
if username == user['username']:
|
||||
found = user
|
||||
|
||||
|
@ -94,8 +95,8 @@ class Data:
|
|||
|
||||
def hash_password(self, password: str, for_storage: bool = False) -> bytes:
|
||||
"""Encode a password."""
|
||||
hashed = hashlib.pbkdf2_hmac(
|
||||
'sha512', password.encode(), self._data['salt'].encode(), 100000)
|
||||
salt = self._data['salt'].encode() # type: ignore
|
||||
hashed = hashlib.pbkdf2_hmac('sha512', password.encode(), salt, 100000)
|
||||
if for_storage:
|
||||
hashed = base64.b64encode(hashed)
|
||||
return hashed
|
||||
|
@ -137,7 +138,7 @@ class Data:
|
|||
else:
|
||||
raise InvalidUser
|
||||
|
||||
async def async_save(self):
|
||||
async def async_save(self) -> None:
|
||||
"""Save data."""
|
||||
await self._store.async_save(self._data)
|
||||
|
||||
|
@ -150,7 +151,7 @@ class HassAuthProvider(AuthProvider):
|
|||
|
||||
data = None
|
||||
|
||||
async def async_initialize(self):
|
||||
async def async_initialize(self) -> None:
|
||||
"""Initialize the auth provider."""
|
||||
if self.data is not None:
|
||||
return
|
||||
|
@ -158,19 +159,22 @@ class HassAuthProvider(AuthProvider):
|
|||
self.data = Data(self.hass)
|
||||
await self.data.async_load()
|
||||
|
||||
async def async_credential_flow(self, context):
|
||||
async def async_credential_flow(
|
||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||
"""Return a flow to login."""
|
||||
return LoginFlow(self)
|
||||
|
||||
async def async_validate_login(self, username: str, password: str):
|
||||
async def async_validate_login(self, username: str, password: str) -> None:
|
||||
"""Helper to validate a username and password."""
|
||||
if self.data is None:
|
||||
await self.async_initialize()
|
||||
assert self.data is not None
|
||||
|
||||
await self.hass.async_add_executor_job(
|
||||
self.data.validate_login, username, password)
|
||||
|
||||
async def async_get_or_create_credentials(self, flow_result):
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
username = flow_result['username']
|
||||
|
||||
|
@ -183,17 +187,17 @@ class HassAuthProvider(AuthProvider):
|
|||
'username': username
|
||||
})
|
||||
|
||||
async def async_user_meta_for_credentials(self, credentials):
|
||||
async def async_user_meta_for_credentials(
|
||||
self, credentials: Credentials) -> UserMeta:
|
||||
"""Get extra info for this credential."""
|
||||
return {
|
||||
'name': credentials.data['username'],
|
||||
'is_active': True,
|
||||
}
|
||||
return UserMeta(name=credentials.data['username'], is_active=True)
|
||||
|
||||
async def async_will_remove_credentials(self, credentials):
|
||||
async def async_will_remove_credentials(
|
||||
self, credentials: Credentials) -> None:
|
||||
"""When credentials get removed, also remove the auth."""
|
||||
if self.data is None:
|
||||
await self.async_initialize()
|
||||
assert self.data is not None
|
||||
|
||||
try:
|
||||
self.data.async_remove_auth(credentials.data['username'])
|
||||
|
@ -206,11 +210,12 @@ class HassAuthProvider(AuthProvider):
|
|||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
def __init__(self, auth_provider):
|
||||
def __init__(self, auth_provider: HassAuthProvider) -> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
async def async_step_init(
|
||||
self, user_input: Dict[str, str] = None) -> Dict[str, Any]:
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Example auth provider."""
|
||||
from collections import OrderedDict
|
||||
import hmac
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -9,6 +10,7 @@ from homeassistant import data_entry_flow
|
|||
from homeassistant.core import callback
|
||||
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
from ..models import Credentials, UserMeta
|
||||
|
||||
|
||||
USER_SCHEMA = vol.Schema({
|
||||
|
@ -31,12 +33,13 @@ class InvalidAuthError(HomeAssistantError):
|
|||
class ExampleAuthProvider(AuthProvider):
|
||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||
|
||||
async def async_credential_flow(self, context):
|
||||
async def async_credential_flow(
|
||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||
"""Return a flow to login."""
|
||||
return LoginFlow(self)
|
||||
|
||||
@callback
|
||||
def async_validate_login(self, username, password):
|
||||
def async_validate_login(self, username: str, password: str) -> None:
|
||||
"""Helper to validate a username and password."""
|
||||
user = None
|
||||
|
||||
|
@ -56,7 +59,8 @@ class ExampleAuthProvider(AuthProvider):
|
|||
password.encode('utf-8')):
|
||||
raise InvalidAuthError
|
||||
|
||||
async def async_get_or_create_credentials(self, flow_result):
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
username = flow_result['username']
|
||||
|
||||
|
@ -69,32 +73,32 @@ class ExampleAuthProvider(AuthProvider):
|
|||
'username': username
|
||||
})
|
||||
|
||||
async def async_user_meta_for_credentials(self, credentials):
|
||||
async def async_user_meta_for_credentials(
|
||||
self, credentials: Credentials) -> UserMeta:
|
||||
"""Return extra user metadata for credentials.
|
||||
|
||||
Will be used to populate info when creating a new user.
|
||||
"""
|
||||
username = credentials.data['username']
|
||||
info = {
|
||||
'is_active': True,
|
||||
}
|
||||
name = None
|
||||
|
||||
for user in self.config['users']:
|
||||
if user['username'] == username:
|
||||
info['name'] = user.get('name')
|
||||
name = user.get('name')
|
||||
break
|
||||
|
||||
return info
|
||||
return UserMeta(name=name, is_active=True)
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
def __init__(self, auth_provider):
|
||||
def __init__(self, auth_provider: ExampleAuthProvider) -> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
async def async_step_init(
|
||||
self, user_input: Dict[str, str] = None) -> Dict[str, Any]:
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
|
||||
|
@ -111,7 +115,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
data=user_input
|
||||
)
|
||||
|
||||
schema = OrderedDict()
|
||||
schema = OrderedDict() # type: Dict[str, type]
|
||||
schema['username'] = str
|
||||
schema['password'] = str
|
||||
|
||||
|
|
|
@ -5,14 +5,17 @@ It will be removed when auth system production ready
|
|||
"""
|
||||
from collections import OrderedDict
|
||||
import hmac
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.core import callback
|
||||
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
from ..models import Credentials, UserMeta
|
||||
|
||||
|
||||
USER_SCHEMA = vol.Schema({
|
||||
|
@ -36,25 +39,29 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
|||
|
||||
DEFAULT_TITLE = 'Legacy API Password'
|
||||
|
||||
async def async_credential_flow(self, context):
|
||||
async def async_credential_flow(
|
||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||
"""Return a flow to login."""
|
||||
return LoginFlow(self)
|
||||
|
||||
@callback
|
||||
def async_validate_login(self, password):
|
||||
def async_validate_login(self, password: str) -> None:
|
||||
"""Helper to validate a username and password."""
|
||||
if not hasattr(self.hass, 'http'):
|
||||
hass_http = getattr(self.hass, 'http', None) # type: HomeAssistantHTTP
|
||||
|
||||
if not hass_http:
|
||||
raise ValueError('http component is not loaded')
|
||||
|
||||
if self.hass.http.api_password is None:
|
||||
if hass_http.api_password is None:
|
||||
raise ValueError('http component is not configured using'
|
||||
' api_password')
|
||||
|
||||
if not hmac.compare_digest(self.hass.http.api_password.encode('utf-8'),
|
||||
if not hmac.compare_digest(hass_http.api_password.encode('utf-8'),
|
||||
password.encode('utf-8')):
|
||||
raise InvalidAuthError
|
||||
|
||||
async def async_get_or_create_credentials(self, flow_result):
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]) -> Credentials:
|
||||
"""Return LEGACY_USER always."""
|
||||
for credential in await self.async_credentials():
|
||||
if credential.data['username'] == LEGACY_USER:
|
||||
|
@ -64,26 +71,25 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
|||
'username': LEGACY_USER
|
||||
})
|
||||
|
||||
async def async_user_meta_for_credentials(self, credentials):
|
||||
async def async_user_meta_for_credentials(
|
||||
self, credentials: Credentials) -> UserMeta:
|
||||
"""
|
||||
Set name as LEGACY_USER always.
|
||||
|
||||
Will be used to populate info when creating a new user.
|
||||
"""
|
||||
return {
|
||||
'name': LEGACY_USER,
|
||||
'is_active': True,
|
||||
}
|
||||
return UserMeta(name=LEGACY_USER, is_active=True)
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
def __init__(self, auth_provider):
|
||||
def __init__(self, auth_provider: LegacyApiPasswordAuthProvider) -> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
async def async_step_init(
|
||||
self, user_input: Dict[str, str] = None) -> Dict[str, Any]:
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
|
||||
|
@ -100,7 +106,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
data={}
|
||||
)
|
||||
|
||||
schema = OrderedDict()
|
||||
schema = OrderedDict() # type: Dict[str, type]
|
||||
schema['password'] = str
|
||||
|
||||
return self.async_show_form(
|
||||
|
|
|
@ -3,12 +3,16 @@
|
|||
It shows list of users if access from trusted network.
|
||||
Abort login flow if not access from trusted network.
|
||||
"""
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
|
||||
from ..models import Credentials, UserMeta
|
||||
|
||||
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
|
||||
}, extra=vol.PREVENT_EXTRA)
|
||||
|
@ -31,16 +35,20 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
|
||||
DEFAULT_TITLE = 'Trusted Networks'
|
||||
|
||||
async def async_credential_flow(self, context):
|
||||
async def async_credential_flow(
|
||||
self, context: Optional[Dict]) -> 'LoginFlow':
|
||||
"""Return a flow to login."""
|
||||
assert context is not None
|
||||
users = await self.store.async_get_users()
|
||||
available_users = {user.id: user.name
|
||||
for user in users
|
||||
if not user.system_generated and user.is_active}
|
||||
|
||||
return LoginFlow(self, context.get('ip_address'), available_users)
|
||||
return LoginFlow(self, cast(str, context.get('ip_address')),
|
||||
available_users)
|
||||
|
||||
async def async_get_or_create_credentials(self, flow_result):
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
user_id = flow_result['user']
|
||||
|
||||
|
@ -59,7 +67,8 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
# We only allow login as exist user
|
||||
raise InvalidUserError
|
||||
|
||||
async def async_user_meta_for_credentials(self, credentials):
|
||||
async def async_user_meta_for_credentials(
|
||||
self, credentials: Credentials) -> UserMeta:
|
||||
"""Return extra user metadata for credentials.
|
||||
|
||||
Trusted network auth provider should never create new user.
|
||||
|
@ -67,31 +76,36 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
raise NotImplementedError
|
||||
|
||||
@callback
|
||||
def async_validate_access(self, ip_address):
|
||||
def async_validate_access(self, ip_address: str) -> None:
|
||||
"""Make sure the access from trusted networks.
|
||||
|
||||
Raise InvalidAuthError if not.
|
||||
Raise InvalidAuthError if trusted_networks is not config
|
||||
Raise InvalidAuthError if trusted_networks is not configured.
|
||||
"""
|
||||
if (not hasattr(self.hass, 'http') or
|
||||
not self.hass.http or not self.hass.http.trusted_networks):
|
||||
hass_http = getattr(self.hass, 'http', None) # type: HomeAssistantHTTP
|
||||
|
||||
if not hass_http or not hass_http.trusted_networks:
|
||||
raise InvalidAuthError('trusted_networks is not configured')
|
||||
|
||||
if not any(ip_address in trusted_network for trusted_network
|
||||
in self.hass.http.trusted_networks):
|
||||
in hass_http.trusted_networks):
|
||||
raise InvalidAuthError('Not in trusted_networks')
|
||||
|
||||
|
||||
class LoginFlow(data_entry_flow.FlowHandler):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
def __init__(self, auth_provider, ip_address, available_users):
|
||||
def __init__(self, auth_provider: TrustedNetworksAuthProvider,
|
||||
ip_address: str, available_users: Dict[str, Optional[str]]) \
|
||||
-> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
self._available_users = available_users
|
||||
self._ip_address = ip_address
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None) \
|
||||
-> Dict[str, Any]:
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue