Allow auth providers to influence is_active ()

* Allow auth providers to influence is_active

* Fix auth script test
pull/15560/head
Paulus Schoutsen 2018-07-19 22:10:36 +02:00 committed by GitHub
parent a42288d056
commit 2fcacbff23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 82 additions and 23 deletions

View File

@ -124,6 +124,7 @@ class AuthManager:
return await self._store.async_create_user(
credentials=credentials,
name=info.get('name'),
is_active=info.get('is_active', False)
)
async def async_link_user(self, user, credentials):

View File

@ -135,5 +135,9 @@ class AuthProvider:
"""Return extra user metadata for credentials.
Will be used to populate info when creating a new user.
Values to populate:
- name: string
- is_active: boolean
"""
return {}

View File

@ -184,7 +184,8 @@ class HassAuthProvider(AuthProvider):
async def async_user_meta_for_credentials(self, credentials):
"""Get extra info for this credential."""
return {
'name': credentials.data['username']
'name': credentials.data['username'],
'is_active': True,
}
async def async_will_remove_credentials(self, credentials):

View File

@ -75,14 +75,16 @@ class ExampleAuthProvider(AuthProvider):
Will be used to populate info when creating a new user.
"""
username = credentials.data['username']
info = {
'is_active': True,
}
for user in self.config['users']:
if user['username'] == username:
return {
'name': user.get('name')
}
info['name'] = user.get('name')
break
return {}
return info
class LoginFlow(data_entry_flow.FlowHandler):

View File

@ -70,7 +70,10 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
Will be used to populate info when creating a new user.
"""
return {'name': LEGACY_USER}
return {
'name': LEGACY_USER,
'is_active': True,
}
class LoginFlow(data_entry_flow.FlowHandler):

View File

@ -81,16 +81,9 @@ async def add_user(hass, provider, args):
print("Username already exists!")
return
credentials = await provider.async_get_or_create_credentials({
'username': args.username
})
user = await hass.auth.async_create_user(args.username)
await hass.auth.async_link_user(user, credentials)
# Save username/password
await provider.data.async_save()
print("User created")
print("Auth created")
async def validate_login(hass, provider, args):

View File

@ -4,6 +4,7 @@ from unittest.mock import Mock
import pytest
from homeassistant import data_entry_flow
from homeassistant.auth import auth_manager_from_config
from homeassistant.auth.providers import (
auth_provider_from_config, homeassistant as hass_auth)
@ -112,3 +113,20 @@ async def test_not_allow_set_id():
'id': 'invalid',
})
assert provider is None
async def test_new_users_populate_values(hass, data):
"""Test that we populate data for new users."""
data.add_auth('hello', 'test-pass')
await data.async_save()
manager = await auth_manager_from_config(hass, [{
'type': 'homeassistant'
}])
provider = manager.auth_providers[0]
credentials = await provider.async_get_or_create_credentials({
'username': 'hello'
})
user = await manager.async_get_or_create_user(credentials)
assert user.name == 'hello'
assert user.is_active

View File

@ -4,7 +4,7 @@ import uuid
import pytest
from homeassistant.auth import auth_store, models as auth_models
from homeassistant.auth import auth_store, models as auth_models, AuthManager
from homeassistant.auth.providers import insecure_example
from tests.common import mock_coro
@ -23,6 +23,7 @@ def provider(hass, store):
'type': 'insecure_example',
'users': [
{
'name': 'Test Name',
'username': 'user-test',
'password': 'password-test',
},
@ -34,7 +35,15 @@ def provider(hass, store):
})
async def test_create_new_credential(provider):
@pytest.fixture
def manager(hass, store, provider):
"""Mock manager."""
return AuthManager(hass, store, {
(provider.type, provider.id): provider
})
async def test_create_new_credential(manager, provider):
"""Test that we create a new credential."""
credentials = await provider.async_get_or_create_credentials({
'username': 'user-test',
@ -42,6 +51,10 @@ async def test_create_new_credential(provider):
})
assert credentials.is_new is True
user = await manager.async_get_or_create_user(credentials)
assert user.name == 'Test Name'
assert user.is_active
async def test_match_existing_credentials(store, provider):
"""See if we match existing users."""

View File

@ -30,12 +30,16 @@ def manager(hass, store, provider):
})
async def test_create_new_credential(provider):
async def test_create_new_credential(manager, provider):
"""Test that we create a new credential."""
credentials = await provider.async_get_or_create_credentials({})
assert credentials.data["username"] is legacy_api_password.LEGACY_USER
assert credentials.is_new is True
user = await manager.async_get_or_create_user(credentials)
assert user.name == legacy_api_password.LEGACY_USER
assert user.is_active
async def test_only_one_credentials(manager, provider):
"""Call create twice will return same credential."""

View File

@ -40,11 +40,31 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
'code': code
})
# User is not active
assert resp.status == 403
data = await resp.json()
assert data['error'] == 'access_denied'
assert data['error_description'] == 'User is not active'
assert resp.status == 200
tokens = await resp.json()
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
# Use refresh token to get more tokens.
resp = await client.post('/auth/token', data={
'client_id': CLIENT_ID,
'grant_type': 'refresh_token',
'refresh_token': tokens['refresh_token']
})
assert resp.status == 200
tokens = await resp.json()
assert 'refresh_token' not in tokens
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
# Test using access token to hit API.
resp = await client.get('/api/')
assert resp.status == 401
resp = await client.get('/api/', headers={
'authorization': 'Bearer {}'.format(tokens['access_token'])
})
assert resp.status == 200
def test_credential_store_expiration():

View File

@ -47,7 +47,7 @@ async def test_add_user(hass, provider, capsys, hass_storage):
assert len(hass_storage[hass_auth.STORAGE_KEY]['data']['users']) == 1
captured = capsys.readouterr()
assert captured.out == 'User created\n'
assert captured.out == 'Auth created\n'
assert len(data.users) == 1
data.validate_login('paulus', 'test-pass')