diff --git a/homeassistant/auth/providers/trusted_networks.py b/homeassistant/auth/providers/trusted_networks.py index 37e032e58d7..8a7e1d67c6d 100644 --- a/homeassistant/auth/providers/trusted_networks.py +++ b/homeassistant/auth/providers/trusted_networks.py @@ -111,31 +111,19 @@ class TrustedNetworksLoginFlow(LoginFlow): self, user_input: Optional[Dict[str, str]] = None) \ -> Dict[str, Any]: """Handle the step of the form.""" - errors = {} try: cast(TrustedNetworksAuthProvider, self._auth_provider)\ .async_validate_access(self._ip_address) except InvalidAuthError: - errors['base'] = 'invalid_auth' - return self.async_show_form( - step_id='init', - data_schema=None, - errors=errors, + return self.async_abort( + reason='not_whitelisted' ) if user_input is not None: - user_id = user_input['user'] - if user_id not in self._available_users: - errors['base'] = 'invalid_auth' - - if not errors: - return await self.async_finish(user_input) - - schema = {'user': vol.In(self._available_users)} + return await self.async_finish(user_input) return self.async_show_form( step_id='init', - data_schema=vol.Schema(schema), - errors=errors, + data_schema=vol.Schema({'user': vol.In(self._available_users)}), ) diff --git a/tests/auth/providers/test_trusted_networks.py b/tests/auth/providers/test_trusted_networks.py index 4839c72a86a..0ca302f8273 100644 --- a/tests/auth/providers/test_trusted_networks.py +++ b/tests/auth/providers/test_trusted_networks.py @@ -74,16 +74,16 @@ async def test_login_flow(manager, provider): # trusted network didn't loaded flow = await provider.async_login_flow({'ip_address': '127.0.0.1'}) step = await flow.async_step_init() - assert step['step_id'] == 'init' - assert step['errors']['base'] == 'invalid_auth' + assert step['type'] == 'abort' + assert step['reason'] == 'not_whitelisted' provider.hass.http = Mock(trusted_networks=['192.168.0.1']) # not from trusted network flow = await provider.async_login_flow({'ip_address': '127.0.0.1'}) step = await flow.async_step_init() - assert step['step_id'] == 'init' - assert step['errors']['base'] == 'invalid_auth' + assert step['type'] == 'abort' + assert step['reason'] == 'not_whitelisted' # from trusted network, list users flow = await provider.async_login_flow({'ip_address': '192.168.0.1'}) @@ -95,11 +95,6 @@ async def test_login_flow(manager, provider): with pytest.raises(vol.Invalid): assert schema({'user': 'invalid-user'}) - # login with invalid user - step = await flow.async_step_init({'user': 'invalid-user'}) - assert step['step_id'] == 'init' - assert step['errors']['base'] == 'invalid_auth' - # login with valid user step = await flow.async_step_init({'user': user.id}) assert step['type'] == 'create_entry'