diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index bcdcf4de747..359f67ed0a5 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -124,11 +124,7 @@ from aiohttp import web import voluptuous as vol from homeassistant.auth import InvalidAuthError -from homeassistant.auth.models import ( - TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, - Credentials, - User, -) +from homeassistant.auth.models import TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, Credentials from homeassistant.components import websocket_api from homeassistant.components.http.auth import async_sign_path from homeassistant.components.http.ban import log_invalid_auth @@ -179,15 +175,12 @@ SCHEMA_WS_SIGN_PATH = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( ) RESULT_TYPE_CREDENTIALS = "credentials" -RESULT_TYPE_USER = "user" @bind_hass -def create_auth_code( - hass, client_id: str, credential_or_user: Credentials | User -) -> str: +def create_auth_code(hass, client_id: str, credential: Credentials) -> str: """Create an authorization code to fetch tokens.""" - return hass.data[DOMAIN](client_id, credential_or_user) + return hass.data[DOMAIN](client_id, credential) async def async_setup(hass, config): @@ -296,7 +289,7 @@ class TokenView(HomeAssistantView): status_code=HTTPStatus.BAD_REQUEST, ) - credential = self._retrieve_auth(client_id, RESULT_TYPE_CREDENTIALS, code) + credential = self._retrieve_auth(client_id, code) if credential is None or not isinstance(credential, Credentials): return self.json( @@ -399,9 +392,7 @@ class LinkUserView(HomeAssistantView): hass = request.app["hass"] user = request["hass_user"] - credentials = self._retrieve_credentials( - data["client_id"], RESULT_TYPE_CREDENTIALS, data["code"] - ) + credentials = self._retrieve_credentials(data["client_id"], data["code"]) if credentials is None: return self.json_message("Invalid code", status_code=HTTPStatus.BAD_REQUEST) @@ -426,30 +417,25 @@ def _create_auth_code_store(): @callback def store_result(client_id, result): """Store flow result and return a code to retrieve it.""" - if isinstance(result, User): - result_type = RESULT_TYPE_USER - elif isinstance(result, Credentials): - result_type = RESULT_TYPE_CREDENTIALS - else: - raise ValueError("result has to be either User or Credentials") + if not isinstance(result, Credentials): + raise ValueError("result has to be a Credentials instance") code = uuid.uuid4().hex - temp_results[(client_id, result_type, code)] = ( + temp_results[(client_id, code)] = ( dt_util.utcnow(), - result_type, result, ) return code @callback - def retrieve_result(client_id, result_type, code): + def retrieve_result(client_id, code): """Retrieve flow result.""" - key = (client_id, result_type, code) + key = (client_id, code) if key not in temp_results: return None - created, _, result = temp_results.pop(key) + created, result = temp_results.pop(key) # OAuth 4.2.1 # The authorization code MUST expire shortly after it is issued to diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index 2c96f545b41..53cc291a5db 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -3,10 +3,11 @@ from datetime import timedelta from http import HTTPStatus from unittest.mock import patch +import pytest + from homeassistant.auth import InvalidAuthError from homeassistant.auth.models import Credentials from homeassistant.components import auth -from homeassistant.components.auth import RESULT_TYPE_USER from homeassistant.setup import async_setup_component from homeassistant.util.dt import utcnow @@ -15,6 +16,18 @@ from . import async_setup_auth from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser +@pytest.fixture +def mock_credential(): + """Return a mock credential.""" + return Credentials( + id="mock-credential-id", + auth_provider_type="insecure_example", + auth_provider_id=None, + data={"username": "test-user"}, + is_new=False, + ) + + async def async_setup_user_refresh_token(hass): """Create a testing user with a connected credential.""" user = await hass.auth.async_create_user("Test User") @@ -96,29 +109,38 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client): assert resp.status == HTTPStatus.OK -def test_auth_code_store_expiration(): +def test_auth_code_store_expiration(mock_credential): """Test that the auth code store will not return expired tokens.""" store, retrieve = auth._create_auth_code_store() client_id = "bla" - user = MockUser(id="mock_user") now = utcnow() with patch("homeassistant.util.dt.utcnow", return_value=now): - code = store(client_id, user) + code = store(client_id, mock_credential) with patch( "homeassistant.util.dt.utcnow", return_value=now + timedelta(minutes=10) ): - assert retrieve(client_id, RESULT_TYPE_USER, code) is None + assert retrieve(client_id, code) is None with patch("homeassistant.util.dt.utcnow", return_value=now): - code = store(client_id, user) + code = store(client_id, mock_credential) with patch( "homeassistant.util.dt.utcnow", return_value=now + timedelta(minutes=9, seconds=59), ): - assert retrieve(client_id, RESULT_TYPE_USER, code) == user + assert retrieve(client_id, code) == mock_credential + + +def test_auth_code_store_requires_credentials(mock_credential): + """Test we require credentials.""" + store, _retrieve = auth._create_auth_code_store() + + with pytest.raises(ValueError): + store(None, MockUser()) + + store(None, mock_credential) async def test_ws_current_user(hass, hass_ws_client, hass_access_token):