Update trusted networks flow (#16227)
* Update the trusted networks flow * Fix tests * Remove errorspull/16256/head
parent
3e65009ea9
commit
9b01972b41
|
@ -111,31 +111,19 @@ class TrustedNetworksLoginFlow(LoginFlow):
|
||||||
self, user_input: Optional[Dict[str, str]] = None) \
|
self, user_input: Optional[Dict[str, str]] = None) \
|
||||||
-> Dict[str, Any]:
|
-> Dict[str, Any]:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
errors = {}
|
|
||||||
try:
|
try:
|
||||||
cast(TrustedNetworksAuthProvider, self._auth_provider)\
|
cast(TrustedNetworksAuthProvider, self._auth_provider)\
|
||||||
.async_validate_access(self._ip_address)
|
.async_validate_access(self._ip_address)
|
||||||
|
|
||||||
except InvalidAuthError:
|
except InvalidAuthError:
|
||||||
errors['base'] = 'invalid_auth'
|
return self.async_abort(
|
||||||
return self.async_show_form(
|
reason='not_whitelisted'
|
||||||
step_id='init',
|
|
||||||
data_schema=None,
|
|
||||||
errors=errors,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
user_id = user_input['user']
|
|
||||||
if user_id not in self._available_users:
|
|
||||||
errors['base'] = 'invalid_auth'
|
|
||||||
|
|
||||||
if not errors:
|
|
||||||
return await self.async_finish(user_input)
|
return await self.async_finish(user_input)
|
||||||
|
|
||||||
schema = {'user': vol.In(self._available_users)}
|
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id='init',
|
step_id='init',
|
||||||
data_schema=vol.Schema(schema),
|
data_schema=vol.Schema({'user': vol.In(self._available_users)}),
|
||||||
errors=errors,
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -74,16 +74,16 @@ async def test_login_flow(manager, provider):
|
||||||
# trusted network didn't loaded
|
# trusted network didn't loaded
|
||||||
flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
|
flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
|
||||||
step = await flow.async_step_init()
|
step = await flow.async_step_init()
|
||||||
assert step['step_id'] == 'init'
|
assert step['type'] == 'abort'
|
||||||
assert step['errors']['base'] == 'invalid_auth'
|
assert step['reason'] == 'not_whitelisted'
|
||||||
|
|
||||||
provider.hass.http = Mock(trusted_networks=['192.168.0.1'])
|
provider.hass.http = Mock(trusted_networks=['192.168.0.1'])
|
||||||
|
|
||||||
# not from trusted network
|
# not from trusted network
|
||||||
flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
|
flow = await provider.async_login_flow({'ip_address': '127.0.0.1'})
|
||||||
step = await flow.async_step_init()
|
step = await flow.async_step_init()
|
||||||
assert step['step_id'] == 'init'
|
assert step['type'] == 'abort'
|
||||||
assert step['errors']['base'] == 'invalid_auth'
|
assert step['reason'] == 'not_whitelisted'
|
||||||
|
|
||||||
# from trusted network, list users
|
# from trusted network, list users
|
||||||
flow = await provider.async_login_flow({'ip_address': '192.168.0.1'})
|
flow = await provider.async_login_flow({'ip_address': '192.168.0.1'})
|
||||||
|
@ -95,11 +95,6 @@ async def test_login_flow(manager, provider):
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
assert schema({'user': 'invalid-user'})
|
assert schema({'user': 'invalid-user'})
|
||||||
|
|
||||||
# login with invalid user
|
|
||||||
step = await flow.async_step_init({'user': 'invalid-user'})
|
|
||||||
assert step['step_id'] == 'init'
|
|
||||||
assert step['errors']['base'] == 'invalid_auth'
|
|
||||||
|
|
||||||
# login with valid user
|
# login with valid user
|
||||||
step = await flow.async_step_init({'user': user.id})
|
step = await flow.async_step_init({'user': user.id})
|
||||||
assert step['type'] == 'create_entry'
|
assert step['type'] == 'create_entry'
|
||||||
|
|
Loading…
Reference in New Issue