Add type hints to homeassistant.auth ()

* Always load users in auth store before use

* Use namedtuple instead of dict for user meta

* Ignore auth store tokens with invalid created_at

* Add type hints to homeassistant.auth
pull/16021/head
Ville Skyttä 2018-08-16 23:25:41 +03:00 committed by Paulus Schoutsen
parent e9e5bce10c
commit 649f17fe47
9 changed files with 300 additions and 200 deletions

View File

@ -2,7 +2,7 @@
import asyncio
import logging
from collections import OrderedDict
from typing import List, Awaitable
from typing import Any, Dict, List, Optional, Tuple, cast
import jwt
@ -10,15 +10,17 @@ from homeassistant import data_entry_flow
from homeassistant.core import callback, HomeAssistant
from homeassistant.util import dt as dt_util
from . import auth_store
from .providers import auth_provider_from_config
from . import auth_store, models
from .providers import auth_provider_from_config, AuthProvider
_LOGGER = logging.getLogger(__name__)
_ProviderKey = Tuple[str, Optional[str]]
_ProviderDict = Dict[_ProviderKey, AuthProvider]
async def auth_manager_from_config(
hass: HomeAssistant,
provider_configs: List[dict]) -> Awaitable['AuthManager']:
provider_configs: List[Dict[str, Any]]) -> 'AuthManager':
"""Initialize an auth manager from config."""
store = auth_store.AuthStore(hass)
if provider_configs:
@ -26,9 +28,9 @@ async def auth_manager_from_config(
*[auth_provider_from_config(hass, store, config)
for config in provider_configs])
else:
providers = []
providers = ()
# So returned auth providers are in same order as config
provider_hash = OrderedDict()
provider_hash = OrderedDict() # type: _ProviderDict
for provider in providers:
if provider is None:
continue
@ -49,7 +51,8 @@ async def auth_manager_from_config(
class AuthManager:
"""Manage the authentication for Home Assistant."""
def __init__(self, hass, store, providers):
def __init__(self, hass: HomeAssistant, store: auth_store.AuthStore,
providers: _ProviderDict) -> None:
"""Initialize the auth manager."""
self._store = store
self._providers = providers
@ -58,12 +61,12 @@ class AuthManager:
self._async_finish_login_flow)
@property
def active(self):
def active(self) -> bool:
"""Return if any auth providers are registered."""
return bool(self._providers)
@property
def support_legacy(self):
def support_legacy(self) -> bool:
"""
Return if legacy_api_password auth providers are registered.
@ -75,19 +78,19 @@ class AuthManager:
return False
@property
def auth_providers(self):
def auth_providers(self) -> List[AuthProvider]:
"""Return a list of available auth providers."""
return list(self._providers.values())
async def async_get_users(self):
async def async_get_users(self) -> List[models.User]:
"""Retrieve all users."""
return await self._store.async_get_users()
async def async_get_user(self, user_id):
async def async_get_user(self, user_id: str) -> Optional[models.User]:
"""Retrieve a user."""
return await self._store.async_get_user(user_id)
async def async_create_system_user(self, name):
async def async_create_system_user(self, name: str) -> models.User:
"""Create a system user."""
return await self._store.async_create_user(
name=name,
@ -95,19 +98,20 @@ class AuthManager:
is_active=True,
)
async def async_create_user(self, name):
async def async_create_user(self, name: str) -> models.User:
"""Create a user."""
kwargs = {
'name': name,
'is_active': True,
}
} # type: Dict[str, Any]
if await self._user_should_be_owner():
kwargs['is_owner'] = True
return await self._store.async_create_user(**kwargs)
async def async_get_or_create_user(self, credentials):
async def async_get_or_create_user(self, credentials: models.Credentials) \
-> models.User:
"""Get or create a user."""
if not credentials.is_new:
for user in await self._store.async_get_users():
@ -127,15 +131,16 @@ class AuthManager:
return await self._store.async_create_user(
credentials=credentials,
name=info.get('name'),
is_active=info.get('is_active', False)
name=info.name,
is_active=info.is_active,
)
async def async_link_user(self, user, credentials):
async def async_link_user(self, user: models.User,
credentials: models.Credentials) -> None:
"""Link credentials to an existing user."""
await self._store.async_link_user(user, credentials)
async def async_remove_user(self, user):
async def async_remove_user(self, user: models.User) -> None:
"""Remove a user."""
tasks = [
self.async_remove_credentials(credentials)
@ -147,27 +152,32 @@ class AuthManager:
await self._store.async_remove_user(user)
async def async_activate_user(self, user):
async def async_activate_user(self, user: models.User) -> None:
"""Activate a user."""
await self._store.async_activate_user(user)
async def async_deactivate_user(self, user):
async def async_deactivate_user(self, user: models.User) -> None:
"""Deactivate a user."""
if user.is_owner:
raise ValueError('Unable to deactive the owner')
await self._store.async_deactivate_user(user)
async def async_remove_credentials(self, credentials):
async def async_remove_credentials(
self, credentials: models.Credentials) -> None:
"""Remove credentials."""
provider = self._async_get_auth_provider(credentials)
if (provider is not None and
hasattr(provider, 'async_will_remove_credentials')):
await provider.async_will_remove_credentials(credentials)
# https://github.com/python/mypy/issues/1424
await provider.async_will_remove_credentials( # type: ignore
credentials)
await self._store.async_remove_credentials(credentials)
async def async_create_refresh_token(self, user, client_id=None):
async def async_create_refresh_token(self, user: models.User,
client_id: Optional[str] = None) \
-> models.RefreshToken:
"""Create a new refresh token for a user."""
if not user.is_active:
raise ValueError('User is not active')
@ -182,16 +192,19 @@ class AuthManager:
return await self._store.async_create_refresh_token(user, client_id)
async def async_get_refresh_token(self, token_id):
async def async_get_refresh_token(
self, token_id: str) -> Optional[models.RefreshToken]:
"""Get refresh token by id."""
return await self._store.async_get_refresh_token(token_id)
async def async_get_refresh_token_by_token(self, token):
async def async_get_refresh_token_by_token(
self, token: str) -> Optional[models.RefreshToken]:
"""Get refresh token by token."""
return await self._store.async_get_refresh_token_by_token(token)
@callback
def async_create_access_token(self, refresh_token):
def async_create_access_token(self,
refresh_token: models.RefreshToken) -> str:
"""Create a new access token."""
# pylint: disable=no-self-use
return jwt.encode({
@ -200,7 +213,8 @@ class AuthManager:
'exp': dt_util.utcnow() + refresh_token.access_token_expiration,
}, refresh_token.jwt_key, algorithm='HS256').decode()
async def async_validate_access_token(self, token):
async def async_validate_access_token(
self, token: str) -> Optional[models.RefreshToken]:
"""Return if an access token is valid."""
try:
unverif_claims = jwt.decode(token, verify=False)
@ -208,7 +222,7 @@ class AuthManager:
return None
refresh_token = await self.async_get_refresh_token(
unverif_claims.get('iss'))
cast(str, unverif_claims.get('iss')))
if refresh_token is None:
jwt_key = ''
@ -228,18 +242,22 @@ class AuthManager:
except jwt.InvalidTokenError:
return None
if not refresh_token.user.is_active:
if refresh_token is None or not refresh_token.user.is_active:
return None
return refresh_token
async def _async_create_login_flow(self, handler, *, context, data):
async def _async_create_login_flow(
self, handler: _ProviderKey, *, context: Optional[Dict],
data: Optional[Any]) -> data_entry_flow.FlowHandler:
"""Create a login flow."""
auth_provider = self._providers[handler]
return await auth_provider.async_credential_flow(context)
async def _async_finish_login_flow(self, context, result):
async def _async_finish_login_flow(
self, context: Optional[Dict], result: Dict[str, Any]) \
-> Optional[models.Credentials]:
"""Result of a credential login flow."""
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None
@ -249,13 +267,14 @@ class AuthManager:
result['data'])
@callback
def _async_get_auth_provider(self, credentials):
def _async_get_auth_provider(
self, credentials: models.Credentials) -> Optional[AuthProvider]:
"""Helper to get auth provider from a set of credentials."""
auth_provider_key = (credentials.auth_provider_type,
credentials.auth_provider_id)
return self._providers.get(auth_provider_key)
async def _user_should_be_owner(self):
async def _user_should_be_owner(self) -> bool:
"""Determine if user should be owner.
A user should be an owner if it is the first non-system user that is

View File

@ -1,8 +1,11 @@
"""Storage for auth models."""
from collections import OrderedDict
from datetime import timedelta
from logging import getLogger
from typing import Any, Dict, List, Optional # noqa: F401
import hmac
from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util
from . import models
@ -20,35 +23,41 @@ class AuthStore:
called that needs it.
"""
def __init__(self, hass):
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the auth store."""
self.hass = hass
self._users = None
self._users = None # type: Optional[Dict[str, models.User]]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
async def async_get_users(self):
async def async_get_users(self) -> List[models.User]:
"""Retrieve all users."""
if self._users is None:
await self.async_load()
assert self._users is not None
return list(self._users.values())
async def async_get_user(self, user_id):
async def async_get_user(self, user_id: str) -> Optional[models.User]:
"""Retrieve a user by id."""
if self._users is None:
await self.async_load()
assert self._users is not None
return self._users.get(user_id)
async def async_create_user(self, name, is_owner=None, is_active=None,
system_generated=None, credentials=None):
async def async_create_user(
self, name: Optional[str], is_owner: Optional[bool] = None,
is_active: Optional[bool] = None,
system_generated: Optional[bool] = None,
credentials: Optional[models.Credentials] = None) -> models.User:
"""Create a new user."""
if self._users is None:
await self.async_load()
assert self._users is not None
kwargs = {
'name': name
}
} # type: Dict[str, Any]
if is_owner is not None:
kwargs['is_owner'] = is_owner
@ -71,29 +80,39 @@ class AuthStore:
await self.async_link_user(new_user, credentials)
return new_user
async def async_link_user(self, user, credentials):
async def async_link_user(self, user: models.User,
credentials: models.Credentials) -> None:
"""Add credentials to an existing user."""
user.credentials.append(credentials)
await self.async_save()
credentials.is_new = False
async def async_remove_user(self, user):
async def async_remove_user(self, user: models.User) -> None:
"""Remove a user."""
if self._users is None:
await self.async_load()
assert self._users is not None
self._users.pop(user.id)
await self.async_save()
async def async_activate_user(self, user):
async def async_activate_user(self, user: models.User) -> None:
"""Activate a user."""
user.is_active = True
await self.async_save()
async def async_deactivate_user(self, user):
async def async_deactivate_user(self, user: models.User) -> None:
"""Activate a user."""
user.is_active = False
await self.async_save()
async def async_remove_credentials(self, credentials):
async def async_remove_credentials(
self, credentials: models.Credentials) -> None:
"""Remove credentials."""
if self._users is None:
await self.async_load()
assert self._users is not None
for user in self._users.values():
found = None
@ -108,17 +127,21 @@ class AuthStore:
await self.async_save()
async def async_create_refresh_token(self, user, client_id=None):
async def async_create_refresh_token(
self, user: models.User, client_id: Optional[str] = None) \
-> models.RefreshToken:
"""Create a new token for a user."""
refresh_token = models.RefreshToken(user=user, client_id=client_id)
user.refresh_tokens[refresh_token.id] = refresh_token
await self.async_save()
return refresh_token
async def async_get_refresh_token(self, token_id):
async def async_get_refresh_token(
self, token_id: str) -> Optional[models.RefreshToken]:
"""Get refresh token by id."""
if self._users is None:
await self.async_load()
assert self._users is not None
for user in self._users.values():
refresh_token = user.refresh_tokens.get(token_id)
@ -127,10 +150,12 @@ class AuthStore:
return None
async def async_get_refresh_token_by_token(self, token):
async def async_get_refresh_token_by_token(
self, token: str) -> Optional[models.RefreshToken]:
"""Get refresh token by token."""
if self._users is None:
await self.async_load()
assert self._users is not None
found = None
@ -141,7 +166,7 @@ class AuthStore:
return found
async def async_load(self):
async def async_load(self) -> None:
"""Load the users."""
data = await self._store.async_load()
@ -150,7 +175,7 @@ class AuthStore:
if self._users is not None:
return
users = OrderedDict()
users = OrderedDict() # type: Dict[str, models.User]
if data is None:
self._users = users
@ -173,11 +198,17 @@ class AuthStore:
if 'jwt_key' not in rt_dict:
continue
created_at = dt_util.parse_datetime(rt_dict['created_at'])
if created_at is None:
getLogger(__name__).error(
'Ignoring refresh token %(id)s with invalid created_at '
'%(created_at)s for user_id %(user_id)s', rt_dict)
continue
token = models.RefreshToken(
id=rt_dict['id'],
user=users[rt_dict['user_id']],
client_id=rt_dict['client_id'],
created_at=dt_util.parse_datetime(rt_dict['created_at']),
created_at=created_at,
access_token_expiration=timedelta(
seconds=rt_dict['access_token_expiration']),
token=rt_dict['token'],
@ -187,8 +218,12 @@ class AuthStore:
self._users = users
async def async_save(self):
async def async_save(self) -> None:
"""Save users."""
if self._users is None:
await self.async_load()
assert self._users is not None
users = [
{
'id': user.id,

View File

@ -1,5 +1,6 @@
"""Auth models."""
from datetime import datetime, timedelta
from typing import Dict, List, NamedTuple, Optional # noqa: F401
import uuid
import attr
@ -14,17 +15,21 @@ from .util import generate_secret
class User:
"""A user."""
name = attr.ib(type=str)
name = attr.ib(type=str) # type: Optional[str]
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
is_owner = attr.ib(type=bool, default=False)
is_active = attr.ib(type=bool, default=False)
system_generated = attr.ib(type=bool, default=False)
# List of credentials of a user.
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
credentials = attr.ib(
type=list, default=attr.Factory(list), cmp=False
) # type: List[Credentials]
# Tokens associated with a user.
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)
refresh_tokens = attr.ib(
type=dict, default=attr.Factory(dict), cmp=False
) # type: Dict[str, RefreshToken]
@attr.s(slots=True)
@ -32,7 +37,7 @@ class RefreshToken:
"""RefreshToken for a user to grant new access tokens."""
user = attr.ib(type=User)
client_id = attr.ib(type=str)
client_id = attr.ib(type=str) # type: Optional[str]
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
access_token_expiration = attr.ib(type=timedelta,
@ -48,10 +53,14 @@ class Credentials:
"""Credentials for a user on an auth provider."""
auth_provider_type = attr.ib(type=str)
auth_provider_id = attr.ib(type=str)
auth_provider_id = attr.ib(type=str) # type: Optional[str]
# Allow the auth provider to store data to represent their auth.
data = attr.ib(type=dict)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
is_new = attr.ib(type=bool, default=True)
UserMeta = NamedTuple("UserMeta",
[('name', Optional[str]), ('is_active', bool)])

View File

@ -1,16 +1,19 @@
"""Auth providers for Home Assistant."""
import importlib
import logging
import types
from typing import Any, Dict, List, Optional
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant import requirements
from homeassistant.core import callback
from homeassistant import data_entry_flow, requirements
from homeassistant.core import callback, HomeAssistant
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
from homeassistant.util.decorator import Registry
from homeassistant.auth.models import Credentials
from homeassistant.auth.auth_store import AuthStore
from homeassistant.auth.models import Credentials, UserMeta
_LOGGER = logging.getLogger(__name__)
DATA_REQS = 'auth_prov_reqs_processed'
@ -25,7 +28,80 @@ AUTH_PROVIDER_SCHEMA = vol.Schema({
}, extra=vol.ALLOW_EXTRA)
async def auth_provider_from_config(hass, store, config):
class AuthProvider:
"""Provider of user authentication."""
DEFAULT_TITLE = 'Unnamed auth provider'
def __init__(self, hass: HomeAssistant, store: AuthStore,
config: Dict[str, Any]) -> None:
"""Initialize an auth provider."""
self.hass = hass
self.store = store
self.config = config
@property
def id(self) -> Optional[str]: # pylint: disable=invalid-name
"""Return id of the auth provider.
Optional, can be None.
"""
return self.config.get(CONF_ID)
@property
def type(self) -> str:
"""Return type of the provider."""
return self.config[CONF_TYPE] # type: ignore
@property
def name(self) -> str:
"""Return the name of the auth provider."""
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
async def async_credentials(self) -> List[Credentials]:
"""Return all credentials of this provider."""
users = await self.store.async_get_users()
return [
credentials
for user in users
for credentials in user.credentials
if (credentials.auth_provider_type == self.type and
credentials.auth_provider_id == self.id)
]
@callback
def async_create_credentials(self, data: Dict[str, str]) -> Credentials:
"""Create credentials."""
return Credentials(
auth_provider_type=self.type,
auth_provider_id=self.id,
data=data,
)
# Implement by extending class
async def async_credential_flow(
self, context: Optional[Dict]) -> data_entry_flow.FlowHandler:
"""Return the data flow for logging in with auth provider."""
raise NotImplementedError
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials:
"""Get credentials based on the flow result."""
raise NotImplementedError
async def async_user_meta_for_credentials(
self, credentials: Credentials) -> UserMeta:
"""Return extra user metadata for credentials.
Will be used to populate info when creating a new user.
"""
raise NotImplementedError
async def auth_provider_from_config(
hass: HomeAssistant, store: AuthStore,
config: Dict[str, Any]) -> Optional[AuthProvider]:
"""Initialize an auth provider from a config."""
provider_name = config[CONF_TYPE]
module = await load_auth_provider_module(hass, provider_name)
@ -34,16 +110,17 @@ async def auth_provider_from_config(hass, store, config):
return None
try:
config = module.CONFIG_SCHEMA(config)
config = module.CONFIG_SCHEMA(config) # type: ignore
except vol.Invalid as err:
_LOGGER.error('Invalid configuration for auth provider %s: %s',
provider_name, humanize_error(config, err))
return None
return AUTH_PROVIDERS[provider_name](hass, store, config)
return AUTH_PROVIDERS[provider_name](hass, store, config) # type: ignore
async def load_auth_provider_module(hass, provider):
async def load_auth_provider_module(
hass: HomeAssistant, provider: str) -> Optional[types.ModuleType]:
"""Load an auth provider."""
try:
module = importlib.import_module(
@ -62,82 +139,13 @@ async def load_auth_provider_module(hass, provider):
elif provider in processed:
return module
# https://github.com/python/mypy/issues/1424
reqs = module.REQUIREMENTS # type: ignore
req_success = await requirements.async_process_requirements(
hass, 'auth provider {}'.format(provider), module.REQUIREMENTS)
hass, 'auth provider {}'.format(provider), reqs)
if not req_success:
return None
processed.add(provider)
return module
class AuthProvider:
"""Provider of user authentication."""
DEFAULT_TITLE = 'Unnamed auth provider'
def __init__(self, hass, store, config):
"""Initialize an auth provider."""
self.hass = hass
self.store = store
self.config = config
@property
def id(self): # pylint: disable=invalid-name
"""Return id of the auth provider.
Optional, can be None.
"""
return self.config.get(CONF_ID)
@property
def type(self):
"""Return type of the provider."""
return self.config[CONF_TYPE]
@property
def name(self):
"""Return the name of the auth provider."""
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
async def async_credentials(self):
"""Return all credentials of this provider."""
users = await self.store.async_get_users()
return [
credentials
for user in users
for credentials in user.credentials
if (credentials.auth_provider_type == self.type and
credentials.auth_provider_id == self.id)
]
@callback
def async_create_credentials(self, data):
"""Create credentials."""
return Credentials(
auth_provider_type=self.type,
auth_provider_id=self.id,
data=data,
)
# Implement by extending class
async def async_credential_flow(self, context):
"""Return the data flow for logging in with auth provider."""
raise NotImplementedError
async def async_get_or_create_credentials(self, flow_result):
"""Get credentials based on the flow result."""
raise NotImplementedError
async def async_user_meta_for_credentials(self, credentials):
"""Return extra user metadata for credentials.
Will be used to populate info when creating a new user.
Values to populate:
- name: string
- is_active: boolean
"""
return {}

View File

@ -3,24 +3,25 @@ import base64
from collections import OrderedDict
import hashlib
import hmac
from typing import Dict # noqa: F401 pylint: disable=unused-import
from typing import Any, Dict, List, Optional # noqa: F401,E501 pylint: disable=unused-import
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.const import CONF_ID
from homeassistant.core import callback
from homeassistant.core import callback, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.auth.util import generate_secret
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
from ..models import Credentials, UserMeta
STORAGE_VERSION = 1
STORAGE_KEY = 'auth_provider.homeassistant'
def _disallow_id(conf):
def _disallow_id(conf: Dict[str, Any]) -> Dict[str, Any]:
"""Disallow ID in config."""
if CONF_ID in conf:
raise vol.Invalid(
@ -46,13 +47,13 @@ class InvalidUser(HomeAssistantError):
class Data:
"""Hold the user data."""
def __init__(self, hass):
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the user data store."""
self.hass = hass
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
self._data = None
self._data = None # type: Optional[Dict[str, Any]]
async def async_load(self):
async def async_load(self) -> None:
"""Load stored data."""
data = await self._store.async_load()
@ -65,9 +66,9 @@ class Data:
self._data = data
@property
def users(self):
def users(self) -> List[Dict[str, str]]:
"""Return users."""
return self._data['users']
return self._data['users'] # type: ignore
def validate_login(self, username: str, password: str) -> None:
"""Validate a username and password.
@ -79,7 +80,7 @@ class Data:
found = None
# Compare all users to avoid timing attacks.
for user in self._data['users']:
for user in self.users:
if username == user['username']:
found = user
@ -94,8 +95,8 @@ class Data:
def hash_password(self, password: str, for_storage: bool = False) -> bytes:
"""Encode a password."""
hashed = hashlib.pbkdf2_hmac(
'sha512', password.encode(), self._data['salt'].encode(), 100000)
salt = self._data['salt'].encode() # type: ignore
hashed = hashlib.pbkdf2_hmac('sha512', password.encode(), salt, 100000)
if for_storage:
hashed = base64.b64encode(hashed)
return hashed
@ -137,7 +138,7 @@ class Data:
else:
raise InvalidUser
async def async_save(self):
async def async_save(self) -> None:
"""Save data."""
await self._store.async_save(self._data)
@ -150,7 +151,7 @@ class HassAuthProvider(AuthProvider):
data = None
async def async_initialize(self):
async def async_initialize(self) -> None:
"""Initialize the auth provider."""
if self.data is not None:
return
@ -158,19 +159,22 @@ class HassAuthProvider(AuthProvider):
self.data = Data(self.hass)
await self.data.async_load()
async def async_credential_flow(self, context):
async def async_credential_flow(
self, context: Optional[Dict]) -> 'LoginFlow':
"""Return a flow to login."""
return LoginFlow(self)
async def async_validate_login(self, username: str, password: str):
async def async_validate_login(self, username: str, password: str) -> None:
"""Helper to validate a username and password."""
if self.data is None:
await self.async_initialize()
assert self.data is not None
await self.hass.async_add_executor_job(
self.data.validate_login, username, password)
async def async_get_or_create_credentials(self, flow_result):
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials:
"""Get credentials based on the flow result."""
username = flow_result['username']
@ -183,17 +187,17 @@ class HassAuthProvider(AuthProvider):
'username': username
})
async def async_user_meta_for_credentials(self, credentials):
async def async_user_meta_for_credentials(
self, credentials: Credentials) -> UserMeta:
"""Get extra info for this credential."""
return {
'name': credentials.data['username'],
'is_active': True,
}
return UserMeta(name=credentials.data['username'], is_active=True)
async def async_will_remove_credentials(self, credentials):
async def async_will_remove_credentials(
self, credentials: Credentials) -> None:
"""When credentials get removed, also remove the auth."""
if self.data is None:
await self.async_initialize()
assert self.data is not None
try:
self.data.async_remove_auth(credentials.data['username'])
@ -206,11 +210,12 @@ class HassAuthProvider(AuthProvider):
class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow."""
def __init__(self, auth_provider):
def __init__(self, auth_provider: HassAuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
async def async_step_init(self, user_input=None):
async def async_step_init(
self, user_input: Dict[str, str] = None) -> Dict[str, Any]:
"""Handle the step of the form."""
errors = {}

View File

@ -1,6 +1,7 @@
"""Example auth provider."""
from collections import OrderedDict
import hmac
from typing import Any, Dict, Optional
import voluptuous as vol
@ -9,6 +10,7 @@ from homeassistant import data_entry_flow
from homeassistant.core import callback
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
from ..models import Credentials, UserMeta
USER_SCHEMA = vol.Schema({
@ -31,12 +33,13 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords."""
async def async_credential_flow(self, context):
async def async_credential_flow(
self, context: Optional[Dict]) -> 'LoginFlow':
"""Return a flow to login."""
return LoginFlow(self)
@callback
def async_validate_login(self, username, password):
def async_validate_login(self, username: str, password: str) -> None:
"""Helper to validate a username and password."""
user = None
@ -56,7 +59,8 @@ class ExampleAuthProvider(AuthProvider):
password.encode('utf-8')):
raise InvalidAuthError
async def async_get_or_create_credentials(self, flow_result):
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials:
"""Get credentials based on the flow result."""
username = flow_result['username']
@ -69,32 +73,32 @@ class ExampleAuthProvider(AuthProvider):
'username': username
})
async def async_user_meta_for_credentials(self, credentials):
async def async_user_meta_for_credentials(
self, credentials: Credentials) -> UserMeta:
"""Return extra user metadata for credentials.
Will be used to populate info when creating a new user.
"""
username = credentials.data['username']
info = {
'is_active': True,
}
name = None
for user in self.config['users']:
if user['username'] == username:
info['name'] = user.get('name')
name = user.get('name')
break
return info
return UserMeta(name=name, is_active=True)
class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow."""
def __init__(self, auth_provider):
def __init__(self, auth_provider: ExampleAuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
async def async_step_init(self, user_input=None):
async def async_step_init(
self, user_input: Dict[str, str] = None) -> Dict[str, Any]:
"""Handle the step of the form."""
errors = {}
@ -111,7 +115,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
data=user_input
)
schema = OrderedDict()
schema = OrderedDict() # type: Dict[str, type]
schema['username'] = str
schema['password'] = str

View File

@ -5,14 +5,17 @@ It will be removed when auth system production ready
"""
from collections import OrderedDict
import hmac
from typing import Any, Dict, Optional
import voluptuous as vol
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
from homeassistant.exceptions import HomeAssistantError
from homeassistant import data_entry_flow
from homeassistant.core import callback
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
from ..models import Credentials, UserMeta
USER_SCHEMA = vol.Schema({
@ -36,25 +39,29 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
DEFAULT_TITLE = 'Legacy API Password'
async def async_credential_flow(self, context):
async def async_credential_flow(
self, context: Optional[Dict]) -> 'LoginFlow':
"""Return a flow to login."""
return LoginFlow(self)
@callback
def async_validate_login(self, password):
def async_validate_login(self, password: str) -> None:
"""Helper to validate a username and password."""
if not hasattr(self.hass, 'http'):
hass_http = getattr(self.hass, 'http', None) # type: HomeAssistantHTTP
if not hass_http:
raise ValueError('http component is not loaded')
if self.hass.http.api_password is None:
if hass_http.api_password is None:
raise ValueError('http component is not configured using'
' api_password')
if not hmac.compare_digest(self.hass.http.api_password.encode('utf-8'),
if not hmac.compare_digest(hass_http.api_password.encode('utf-8'),
password.encode('utf-8')):
raise InvalidAuthError
async def async_get_or_create_credentials(self, flow_result):
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials:
"""Return LEGACY_USER always."""
for credential in await self.async_credentials():
if credential.data['username'] == LEGACY_USER:
@ -64,26 +71,25 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
'username': LEGACY_USER
})
async def async_user_meta_for_credentials(self, credentials):
async def async_user_meta_for_credentials(
self, credentials: Credentials) -> UserMeta:
"""
Set name as LEGACY_USER always.
Will be used to populate info when creating a new user.
"""
return {
'name': LEGACY_USER,
'is_active': True,
}
return UserMeta(name=LEGACY_USER, is_active=True)
class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow."""
def __init__(self, auth_provider):
def __init__(self, auth_provider: LegacyApiPasswordAuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
async def async_step_init(self, user_input=None):
async def async_step_init(
self, user_input: Dict[str, str] = None) -> Dict[str, Any]:
"""Handle the step of the form."""
errors = {}
@ -100,7 +106,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
data={}
)
schema = OrderedDict()
schema = OrderedDict() # type: Dict[str, type]
schema['password'] = str
return self.async_show_form(

View File

@ -3,12 +3,16 @@
It shows list of users if access from trusted network.
Abort login flow if not access from trusted network.
"""
from typing import Any, Dict, Optional, cast
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.components.http import HomeAssistantHTTP # noqa: F401
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
from ..models import Credentials, UserMeta
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
}, extra=vol.PREVENT_EXTRA)
@ -31,16 +35,20 @@ class TrustedNetworksAuthProvider(AuthProvider):
DEFAULT_TITLE = 'Trusted Networks'
async def async_credential_flow(self, context):
async def async_credential_flow(
self, context: Optional[Dict]) -> 'LoginFlow':
"""Return a flow to login."""
assert context is not None
users = await self.store.async_get_users()
available_users = {user.id: user.name
for user in users
if not user.system_generated and user.is_active}
return LoginFlow(self, context.get('ip_address'), available_users)
return LoginFlow(self, cast(str, context.get('ip_address')),
available_users)
async def async_get_or_create_credentials(self, flow_result):
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials:
"""Get credentials based on the flow result."""
user_id = flow_result['user']
@ -59,7 +67,8 @@ class TrustedNetworksAuthProvider(AuthProvider):
# We only allow login as exist user
raise InvalidUserError
async def async_user_meta_for_credentials(self, credentials):
async def async_user_meta_for_credentials(
self, credentials: Credentials) -> UserMeta:
"""Return extra user metadata for credentials.
Trusted network auth provider should never create new user.
@ -67,31 +76,36 @@ class TrustedNetworksAuthProvider(AuthProvider):
raise NotImplementedError
@callback
def async_validate_access(self, ip_address):
def async_validate_access(self, ip_address: str) -> None:
"""Make sure the access from trusted networks.
Raise InvalidAuthError if not.
Raise InvalidAuthError if trusted_networks is not config
Raise InvalidAuthError if trusted_networks is not configured.
"""
if (not hasattr(self.hass, 'http') or
not self.hass.http or not self.hass.http.trusted_networks):
hass_http = getattr(self.hass, 'http', None) # type: HomeAssistantHTTP
if not hass_http or not hass_http.trusted_networks:
raise InvalidAuthError('trusted_networks is not configured')
if not any(ip_address in trusted_network for trusted_network
in self.hass.http.trusted_networks):
in hass_http.trusted_networks):
raise InvalidAuthError('Not in trusted_networks')
class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow."""
def __init__(self, auth_provider, ip_address, available_users):
def __init__(self, auth_provider: TrustedNetworksAuthProvider,
ip_address: str, available_users: Dict[str, Optional[str]]) \
-> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
self._available_users = available_users
self._ip_address = ip_address
async def async_step_init(self, user_input=None):
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \
-> Dict[str, Any]:
"""Handle the step of the form."""
errors = {}
try:

View File

@ -58,4 +58,4 @@ whitelist_externals=/bin/bash
deps =
-r{toxinidir}/requirements_test.txt
commands =
/bin/bash -c 'mypy homeassistant/*.py homeassistant/util/'
/bin/bash -c 'mypy homeassistant/*.py homeassistant/auth/ homeassistant/util/'