Convert getting and removing access tokens to normal functions (#108670)
parent
904032e944
commit
2eea658fd8
|
@ -458,23 +458,22 @@ class AuthManager:
|
|||
credential,
|
||||
)
|
||||
|
||||
async def async_get_refresh_token(
|
||||
self, token_id: str
|
||||
) -> models.RefreshToken | None:
|
||||
@callback
|
||||
def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None:
|
||||
"""Get refresh token by id."""
|
||||
return await self._store.async_get_refresh_token(token_id)
|
||||
return self._store.async_get_refresh_token(token_id)
|
||||
|
||||
async def async_get_refresh_token_by_token(
|
||||
@callback
|
||||
def async_get_refresh_token_by_token(
|
||||
self, token: str
|
||||
) -> models.RefreshToken | None:
|
||||
"""Get refresh token by token."""
|
||||
return await self._store.async_get_refresh_token_by_token(token)
|
||||
return self._store.async_get_refresh_token_by_token(token)
|
||||
|
||||
async def async_remove_refresh_token(
|
||||
self, refresh_token: models.RefreshToken
|
||||
) -> None:
|
||||
@callback
|
||||
def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None:
|
||||
"""Delete a refresh token."""
|
||||
await self._store.async_remove_refresh_token(refresh_token)
|
||||
self._store.async_remove_refresh_token(refresh_token)
|
||||
|
||||
callbacks = self._revoke_callbacks.pop(refresh_token.id, ())
|
||||
for revoke_callback in callbacks:
|
||||
|
@ -554,16 +553,15 @@ class AuthManager:
|
|||
if provider := self._async_resolve_provider(refresh_token):
|
||||
provider.async_validate_refresh_token(refresh_token, remote_ip)
|
||||
|
||||
async def async_validate_access_token(
|
||||
self, token: str
|
||||
) -> models.RefreshToken | None:
|
||||
@callback
|
||||
def async_validate_access_token(self, token: str) -> models.RefreshToken | None:
|
||||
"""Return refresh token if an access token is valid."""
|
||||
try:
|
||||
unverif_claims = jwt_wrapper.unverified_hs256_token_decode(token)
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
||||
|
||||
refresh_token = await self.async_get_refresh_token(
|
||||
refresh_token = self.async_get_refresh_token(
|
||||
cast(str, unverif_claims.get("iss"))
|
||||
)
|
||||
|
||||
|
|
|
@ -207,18 +207,16 @@ class AuthStore:
|
|||
self._async_schedule_save()
|
||||
return refresh_token
|
||||
|
||||
async def async_remove_refresh_token(
|
||||
self, refresh_token: models.RefreshToken
|
||||
) -> None:
|
||||
@callback
|
||||
def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None:
|
||||
"""Remove a refresh token."""
|
||||
for user in self._users.values():
|
||||
if user.refresh_tokens.pop(refresh_token.id, None):
|
||||
self._async_schedule_save()
|
||||
break
|
||||
|
||||
async def async_get_refresh_token(
|
||||
self, token_id: str
|
||||
) -> models.RefreshToken | None:
|
||||
@callback
|
||||
def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None:
|
||||
"""Get refresh token by id."""
|
||||
for user in self._users.values():
|
||||
refresh_token = user.refresh_tokens.get(token_id)
|
||||
|
@ -227,7 +225,8 @@ class AuthStore:
|
|||
|
||||
return None
|
||||
|
||||
async def async_get_refresh_token_by_token(
|
||||
@callback
|
||||
def async_get_refresh_token_by_token(
|
||||
self, token: str
|
||||
) -> models.RefreshToken | None:
|
||||
"""Get refresh token by token."""
|
||||
|
|
|
@ -124,7 +124,6 @@ as part of a config flow.
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta
|
||||
from http import HTTPStatus
|
||||
|
@ -220,12 +219,12 @@ class RevokeTokenView(HomeAssistantView):
|
|||
if (token := data.get("token")) is None:
|
||||
return web.Response(status=HTTPStatus.OK)
|
||||
|
||||
refresh_token = await hass.auth.async_get_refresh_token_by_token(token)
|
||||
refresh_token = hass.auth.async_get_refresh_token_by_token(token)
|
||||
|
||||
if refresh_token is None:
|
||||
return web.Response(status=HTTPStatus.OK)
|
||||
|
||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
||||
hass.auth.async_remove_refresh_token(refresh_token)
|
||||
return web.Response(status=HTTPStatus.OK)
|
||||
|
||||
|
||||
|
@ -355,7 +354,7 @@ class TokenView(HomeAssistantView):
|
|||
{"error": "invalid_request"}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
refresh_token = await hass.auth.async_get_refresh_token_by_token(token)
|
||||
refresh_token = hass.auth.async_get_refresh_token_by_token(token)
|
||||
|
||||
if refresh_token is None:
|
||||
return self.json(
|
||||
|
@ -597,7 +596,7 @@ async def websocket_delete_refresh_token(
|
|||
connection.send_error(msg["id"], "invalid_token_id", "Received invalid token")
|
||||
return
|
||||
|
||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
||||
hass.auth.async_remove_refresh_token(refresh_token)
|
||||
|
||||
connection.send_result(msg["id"], {})
|
||||
|
||||
|
@ -613,28 +612,23 @@ async def websocket_delete_all_refresh_tokens(
|
|||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle delete all refresh tokens request."""
|
||||
tasks = []
|
||||
current_refresh_token: RefreshToken
|
||||
for token in connection.user.refresh_tokens.values():
|
||||
remove_failed = False
|
||||
for token in list(connection.user.refresh_tokens.values()):
|
||||
if token.id == connection.refresh_token_id:
|
||||
# Skip the current refresh token as it has revoke_callback,
|
||||
# which cancels/closes the connection.
|
||||
# It will be removed after sending the result.
|
||||
current_refresh_token = token
|
||||
continue
|
||||
tasks.append(
|
||||
hass.async_create_task(hass.auth.async_remove_refresh_token(token))
|
||||
)
|
||||
|
||||
remove_failed = False
|
||||
if tasks:
|
||||
for result in await asyncio.gather(*tasks, return_exceptions=True):
|
||||
if isinstance(result, Exception):
|
||||
getLogger(__name__).exception(
|
||||
"During refresh token removal, the following error occurred: %s",
|
||||
result,
|
||||
)
|
||||
remove_failed = True
|
||||
try:
|
||||
hass.auth.async_remove_refresh_token(token)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
getLogger(__name__).exception(
|
||||
"During refresh token removal, the following error occurred: %s",
|
||||
err,
|
||||
)
|
||||
remove_failed = True
|
||||
|
||||
if remove_failed:
|
||||
connection.send_error(
|
||||
|
@ -643,7 +637,8 @@ async def websocket_delete_all_refresh_tokens(
|
|||
else:
|
||||
connection.send_result(msg["id"], {})
|
||||
|
||||
hass.async_create_task(hass.auth.async_remove_refresh_token(current_refresh_token))
|
||||
# This will close the connection so we need to send the result first.
|
||||
hass.loop.call_soon(hass.auth.async_remove_refresh_token, current_refresh_token)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
|
|
|
@ -151,7 +151,7 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
|
|||
if auth_type != "Bearer":
|
||||
return False
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(auth_val)
|
||||
refresh_token = hass.auth.async_validate_access_token(auth_val)
|
||||
|
||||
if refresh_token is None:
|
||||
return False
|
||||
|
@ -189,7 +189,7 @@ async def async_setup_auth(hass: HomeAssistant, app: Application) -> None:
|
|||
if claims["params"] != params:
|
||||
return False
|
||||
|
||||
refresh_token = await hass.auth.async_get_refresh_token(claims["iss"])
|
||||
refresh_token = hass.auth.async_get_refresh_token(claims["iss"])
|
||||
|
||||
if refresh_token is None:
|
||||
return False
|
||||
|
|
|
@ -259,7 +259,7 @@ class IntegrationOnboardingView(_BaseOnboardingView):
|
|||
"invalid client id or redirect uri", HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
refresh_token = await hass.auth.async_get_refresh_token(refresh_token_id)
|
||||
refresh_token = hass.auth.async_get_refresh_token(refresh_token_id)
|
||||
if refresh_token is None or refresh_token.credential is None:
|
||||
return self.json_message(
|
||||
"Credentials for user not available", HTTPStatus.FORBIDDEN
|
||||
|
|
|
@ -80,9 +80,7 @@ class AuthPhase:
|
|||
raise Disconnect from err
|
||||
|
||||
if (access_token := valid_msg.get("access_token")) and (
|
||||
refresh_token := await self._hass.auth.async_validate_access_token(
|
||||
access_token
|
||||
)
|
||||
refresh_token := self._hass.auth.async_validate_access_token(access_token)
|
||||
):
|
||||
conn = ActiveConnection(
|
||||
self._logger,
|
||||
|
|
|
@ -371,7 +371,7 @@ async def test_cannot_retrieve_expired_access_token(hass: HomeAssistant) -> None
|
|||
assert refresh_token.client_id == CLIENT_ID
|
||||
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
assert await manager.async_validate_access_token(access_token) is refresh_token
|
||||
assert manager.async_validate_access_token(access_token) is refresh_token
|
||||
|
||||
# We patch time directly here because we want the access token to be created with
|
||||
# an expired time, but we do not want to freeze time so that jwt will compare it
|
||||
|
@ -385,7 +385,7 @@ async def test_cannot_retrieve_expired_access_token(hass: HomeAssistant) -> None
|
|||
):
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
assert await manager.async_validate_access_token(access_token) is None
|
||||
assert manager.async_validate_access_token(access_token) is None
|
||||
|
||||
|
||||
async def test_generating_system_user(hass: HomeAssistant) -> None:
|
||||
|
@ -572,10 +572,10 @@ async def test_remove_refresh_token(mock_hass) -> None:
|
|||
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
await manager.async_remove_refresh_token(refresh_token)
|
||||
manager.async_remove_refresh_token(refresh_token)
|
||||
|
||||
assert await manager.async_get_refresh_token(refresh_token.id) is None
|
||||
assert await manager.async_validate_access_token(access_token) is None
|
||||
assert manager.async_get_refresh_token(refresh_token.id) is None
|
||||
assert manager.async_validate_access_token(access_token) is None
|
||||
|
||||
|
||||
async def test_register_revoke_token_callback(mock_hass) -> None:
|
||||
|
@ -591,7 +591,7 @@ async def test_register_revoke_token_callback(mock_hass) -> None:
|
|||
called = True
|
||||
|
||||
manager.async_register_revoke_token_callback(refresh_token.id, cb)
|
||||
await manager.async_remove_refresh_token(refresh_token)
|
||||
manager.async_remove_refresh_token(refresh_token)
|
||||
assert called
|
||||
|
||||
|
||||
|
@ -610,7 +610,7 @@ async def test_unregister_revoke_token_callback(mock_hass) -> None:
|
|||
unregister = manager.async_register_revoke_token_callback(refresh_token.id, cb)
|
||||
unregister()
|
||||
|
||||
await manager.async_remove_refresh_token(refresh_token)
|
||||
manager.async_remove_refresh_token(refresh_token)
|
||||
assert not called
|
||||
|
||||
|
||||
|
@ -664,7 +664,7 @@ async def test_one_long_lived_access_token_per_refresh_token(mock_hass) -> None:
|
|||
access_token = manager.async_create_access_token(refresh_token)
|
||||
jwt_key = refresh_token.jwt_key
|
||||
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
rt = manager.async_validate_access_token(access_token)
|
||||
assert rt.id == refresh_token.id
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -675,9 +675,9 @@ async def test_one_long_lived_access_token_per_refresh_token(mock_hass) -> None:
|
|||
access_token_expiration=timedelta(days=3000),
|
||||
)
|
||||
|
||||
await manager.async_remove_refresh_token(refresh_token)
|
||||
manager.async_remove_refresh_token(refresh_token)
|
||||
assert refresh_token.id not in user.refresh_tokens
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
rt = manager.async_validate_access_token(access_token)
|
||||
assert rt is None, "Previous issued access token has been invoked"
|
||||
|
||||
refresh_token_2 = await manager.async_create_refresh_token(
|
||||
|
@ -694,7 +694,7 @@ async def test_one_long_lived_access_token_per_refresh_token(mock_hass) -> None:
|
|||
assert access_token != access_token_2
|
||||
assert jwt_key != jwt_key_2
|
||||
|
||||
rt = await manager.async_validate_access_token(access_token_2)
|
||||
rt = manager.async_validate_access_token(access_token_2)
|
||||
jwt_payload = jwt.decode(access_token_2, rt.jwt_key, algorithms=["HS256"])
|
||||
assert jwt_payload["iss"] == refresh_token_2.id
|
||||
assert (
|
||||
|
@ -1144,7 +1144,7 @@ async def test_access_token_with_invalid_signature(mock_hass) -> None:
|
|||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
rt = manager.async_validate_access_token(access_token)
|
||||
assert rt.id == refresh_token.id
|
||||
|
||||
# Now we corrupt the signature
|
||||
|
@ -1154,7 +1154,7 @@ async def test_access_token_with_invalid_signature(mock_hass) -> None:
|
|||
|
||||
assert access_token != invalid_token
|
||||
|
||||
result = await manager.async_validate_access_token(invalid_token)
|
||||
result = manager.async_validate_access_token(invalid_token)
|
||||
assert result is None
|
||||
|
||||
|
||||
|
@ -1171,7 +1171,7 @@ async def test_access_token_with_null_signature(mock_hass) -> None:
|
|||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
rt = manager.async_validate_access_token(access_token)
|
||||
assert rt.id == refresh_token.id
|
||||
|
||||
# Now we make the signature all nulls
|
||||
|
@ -1181,7 +1181,7 @@ async def test_access_token_with_null_signature(mock_hass) -> None:
|
|||
|
||||
assert access_token != invalid_token
|
||||
|
||||
result = await manager.async_validate_access_token(invalid_token)
|
||||
result = manager.async_validate_access_token(invalid_token)
|
||||
assert result is None
|
||||
|
||||
|
||||
|
@ -1198,7 +1198,7 @@ async def test_access_token_with_empty_signature(mock_hass) -> None:
|
|||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
rt = manager.async_validate_access_token(access_token)
|
||||
assert rt.id == refresh_token.id
|
||||
|
||||
# Now we make the signature all nulls
|
||||
|
@ -1207,7 +1207,7 @@ async def test_access_token_with_empty_signature(mock_hass) -> None:
|
|||
|
||||
assert access_token != invalid_token
|
||||
|
||||
result = await manager.async_validate_access_token(invalid_token)
|
||||
result = manager.async_validate_access_token(invalid_token)
|
||||
assert result is None
|
||||
|
||||
|
||||
|
@ -1225,17 +1225,17 @@ async def test_access_token_with_empty_key(mock_hass) -> None:
|
|||
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
await manager.async_remove_refresh_token(refresh_token)
|
||||
manager.async_remove_refresh_token(refresh_token)
|
||||
# Now remove the token from the keyring
|
||||
# so we will get an empty key
|
||||
|
||||
assert await manager.async_validate_access_token(access_token) is None
|
||||
assert manager.async_validate_access_token(access_token) is None
|
||||
|
||||
|
||||
async def test_reject_access_token_with_impossible_large_size(mock_hass) -> None:
|
||||
"""Test rejecting access tokens with impossible sizes."""
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
assert await manager.async_validate_access_token("a" * 10000) is None
|
||||
assert manager.async_validate_access_token("a" * 10000) is None
|
||||
|
||||
|
||||
async def test_reject_token_with_invalid_json_payload(mock_hass) -> None:
|
||||
|
@ -1245,7 +1245,7 @@ async def test_reject_token_with_invalid_json_payload(mock_hass) -> None:
|
|||
b"invalid", b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
|
||||
)
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
assert await manager.async_validate_access_token(token_with_invalid_json) is None
|
||||
assert manager.async_validate_access_token(token_with_invalid_json) is None
|
||||
|
||||
|
||||
async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None:
|
||||
|
@ -1255,7 +1255,7 @@ async def test_reject_token_with_not_dict_json_payload(mock_hass) -> None:
|
|||
b'["invalid"]', b"invalid", "HS256", {"alg": "HS256", "typ": "JWT"}
|
||||
)
|
||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||
assert await manager.async_validate_access_token(token_not_a_dict_json) is None
|
||||
assert manager.async_validate_access_token(token_not_a_dict_json) is None
|
||||
|
||||
|
||||
async def test_access_token_that_expires_soon(mock_hass) -> None:
|
||||
|
@ -1272,11 +1272,11 @@ async def test_access_token_that_expires_soon(mock_hass) -> None:
|
|||
assert refresh_token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
rt = manager.async_validate_access_token(access_token)
|
||||
assert rt.id == refresh_token.id
|
||||
|
||||
with freeze_time(now + timedelta(minutes=1)):
|
||||
assert await manager.async_validate_access_token(access_token) is None
|
||||
assert manager.async_validate_access_token(access_token) is None
|
||||
|
||||
|
||||
async def test_access_token_from_the_future(mock_hass) -> None:
|
||||
|
@ -1296,8 +1296,8 @@ async def test_access_token_from_the_future(mock_hass) -> None:
|
|||
)
|
||||
access_token = manager.async_create_access_token(refresh_token)
|
||||
|
||||
assert await manager.async_validate_access_token(access_token) is None
|
||||
assert manager.async_validate_access_token(access_token) is None
|
||||
|
||||
with freeze_time(now + timedelta(days=365)):
|
||||
rt = await manager.async_validate_access_token(access_token)
|
||||
rt = manager.async_validate_access_token(access_token)
|
||||
assert rt.id == refresh_token.id
|
||||
|
|
|
@ -588,7 +588,7 @@ async def test_api_fire_event_context(
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
assert len(test_value) == 1
|
||||
assert test_value[0].context.user_id == refresh_token.user.id
|
||||
|
@ -606,7 +606,7 @@ async def test_api_call_service_context(
|
|||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].context.user_id == refresh_token.user.id
|
||||
|
@ -622,7 +622,7 @@ async def test_api_set_state_context(
|
|||
headers={"authorization": f"Bearer {hass_access_token}"},
|
||||
)
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
state = hass.states.get("light.kitchen")
|
||||
assert state.context.user_id == refresh_token.user.id
|
||||
|
|
|
@ -88,9 +88,7 @@ async def test_login_new_user_and_trying_refresh_token(
|
|||
assert resp.status == HTTPStatus.OK
|
||||
tokens = await resp.json()
|
||||
|
||||
assert (
|
||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
)
|
||||
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
assert tokens["ha_auth_provider"] == "insecure_example"
|
||||
|
||||
# Use refresh token to get more tokens.
|
||||
|
@ -106,9 +104,7 @@ async def test_login_new_user_and_trying_refresh_token(
|
|||
assert resp.status == HTTPStatus.OK
|
||||
tokens = await resp.json()
|
||||
assert "refresh_token" not in tokens
|
||||
assert (
|
||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
)
|
||||
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
|
||||
# Test using access token to hit API.
|
||||
resp = await client.get("/api/")
|
||||
|
@ -205,7 +201,7 @@ async def test_ws_current_user(
|
|||
"""Test the current user command with Home Assistant creds."""
|
||||
assert await async_setup_component(hass, "auth", {})
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
user = refresh_token.user
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
|
@ -275,9 +271,7 @@ async def test_refresh_token_system_generated(
|
|||
|
||||
assert resp.status == HTTPStatus.OK
|
||||
tokens = await resp.json()
|
||||
assert (
|
||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
)
|
||||
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
|
||||
|
||||
async def test_refresh_token_different_client_id(
|
||||
|
@ -323,9 +317,7 @@ async def test_refresh_token_different_client_id(
|
|||
|
||||
assert resp.status == HTTPStatus.OK
|
||||
tokens = await resp.json()
|
||||
assert (
|
||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
)
|
||||
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
|
||||
|
||||
async def test_refresh_token_checks_local_only_user(
|
||||
|
@ -406,16 +398,14 @@ async def test_revoking_refresh_token(
|
|||
|
||||
assert resp.status == HTTPStatus.OK
|
||||
tokens = await resp.json()
|
||||
assert (
|
||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
)
|
||||
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
|
||||
# Revoke refresh token
|
||||
resp = await client.post(url, data={**base_data, "token": refresh_token.token})
|
||||
assert resp.status == HTTPStatus.OK
|
||||
|
||||
# Old access token should be no longer valid
|
||||
assert await hass.auth.async_validate_access_token(tokens["access_token"]) is None
|
||||
assert hass.auth.async_validate_access_token(tokens["access_token"]) is None
|
||||
|
||||
# Test that we no longer can create an access token
|
||||
resp = await client.post(
|
||||
|
@ -454,7 +444,7 @@ async def test_ws_long_lived_access_token(
|
|||
long_lived_access_token = result["result"]
|
||||
assert long_lived_access_token is not None
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(long_lived_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(long_lived_access_token)
|
||||
assert refresh_token.client_id is None
|
||||
assert refresh_token.client_name == "GPS Logger"
|
||||
assert refresh_token.client_icon is None
|
||||
|
@ -474,7 +464,7 @@ async def test_ws_refresh_tokens(
|
|||
assert result["success"], result
|
||||
assert len(result["result"]) == 1
|
||||
token = result["result"][0]
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
assert token["id"] == refresh_token.id
|
||||
assert token["type"] == refresh_token.token_type
|
||||
assert token["client_id"] == refresh_token.client_id
|
||||
|
@ -514,7 +504,7 @@ async def test_ws_delete_refresh_token(
|
|||
|
||||
result = await ws_client.receive_json()
|
||||
assert result["success"], result
|
||||
refresh_token = await hass.auth.async_get_refresh_token(refresh_token.id)
|
||||
refresh_token = hass.auth.async_get_refresh_token(refresh_token.id)
|
||||
assert refresh_token is None
|
||||
|
||||
|
||||
|
@ -573,7 +563,7 @@ async def test_ws_delete_all_refresh_tokens_error(
|
|||
) in caplog.record_tuples
|
||||
|
||||
for token in tokens:
|
||||
refresh_token = await hass.auth.async_get_refresh_token(token["id"])
|
||||
refresh_token = hass.auth.async_get_refresh_token(token["id"])
|
||||
assert refresh_token is None
|
||||
|
||||
|
||||
|
@ -614,7 +604,7 @@ async def test_ws_delete_all_refresh_tokens(
|
|||
result = await ws_client.receive_json()
|
||||
assert result, result["success"]
|
||||
for token in tokens:
|
||||
refresh_token = await hass.auth.async_get_refresh_token(token["id"])
|
||||
refresh_token = hass.auth.async_get_refresh_token(token["id"])
|
||||
assert refresh_token is None
|
||||
|
||||
|
||||
|
|
|
@ -136,7 +136,7 @@ async def test_delete_unable_self_account(
|
|||
) -> None:
|
||||
"""Test we cannot delete our own account."""
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
await client.send_json(
|
||||
{"id": 5, "type": auth_config.WS_TYPE_DELETE, "user_id": refresh_token.user.id}
|
||||
|
|
|
@ -211,7 +211,7 @@ async def test_auth_active_access_with_access_token_in_header(
|
|||
token = hass_access_token
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
||||
assert req.status == HTTPStatus.OK
|
||||
|
@ -231,7 +231,7 @@ async def test_auth_active_access_with_access_token_in_header(
|
|||
req = await client.get("/", headers={"Authorization": f"BEARER {token}"})
|
||||
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token.user.is_active = False
|
||||
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
||||
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||
|
@ -297,7 +297,7 @@ async def test_auth_access_signed_path_with_refresh_token(
|
|||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
signed_path = async_sign_path(
|
||||
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
|
||||
|
@ -325,7 +325,7 @@ async def test_auth_access_signed_path_with_refresh_token(
|
|||
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||
|
||||
# refresh token gone should also invalidate signature
|
||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
||||
hass.auth.async_remove_refresh_token(refresh_token)
|
||||
req = await client.get(signed_path)
|
||||
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||
|
||||
|
@ -342,7 +342,7 @@ async def test_auth_access_signed_path_with_query_param(
|
|||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
signed_path = async_sign_path(
|
||||
hass, "/?test=test", timedelta(seconds=5), refresh_token_id=refresh_token.id
|
||||
|
@ -372,7 +372,7 @@ async def test_auth_access_signed_path_with_query_param_order(
|
|||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
signed_path = async_sign_path(
|
||||
hass,
|
||||
|
@ -413,7 +413,7 @@ async def test_auth_access_signed_path_with_query_param_safe_param(
|
|||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
signed_path = async_sign_path(
|
||||
hass,
|
||||
|
@ -452,7 +452,7 @@ async def test_auth_access_signed_path_with_query_param_tamper(
|
|||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
signed_path = async_sign_path(
|
||||
hass, base_url, timedelta(seconds=5), refresh_token_id=refresh_token.id
|
||||
|
@ -491,9 +491,7 @@ async def test_auth_access_signed_path_via_websocket(
|
|||
assert msg["id"] == 5
|
||||
assert msg["success"]
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_read_only_access_token
|
||||
)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_read_only_access_token)
|
||||
signature = yarl.URL(msg["result"]["path"]).query["authSig"]
|
||||
claims = jwt.decode(
|
||||
signature,
|
||||
|
@ -523,7 +521,7 @@ async def test_auth_access_signed_path_with_http(
|
|||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
req = await client.get(
|
||||
"/hello", headers={"Authorization": f"Bearer {hass_access_token}"}
|
||||
|
@ -567,7 +565,7 @@ async def test_local_only_user_rejected(
|
|||
await async_setup_auth(hass, app)
|
||||
set_mock_ip = mock_real_ip(app)
|
||||
client = await aiohttp_client(app)
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
||||
assert req.status == HTTPStatus.OK
|
||||
|
|
|
@ -232,9 +232,7 @@ async def test_onboarding_user(
|
|||
assert resp.status == 200
|
||||
tokens = await resp.json()
|
||||
|
||||
assert (
|
||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
)
|
||||
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
|
||||
# Validate created areas
|
||||
assert len(area_registry.areas) == 3
|
||||
|
@ -347,9 +345,7 @@ async def test_onboarding_integration(
|
|||
assert const.STEP_INTEGRATION in hass_storage[const.DOMAIN]["data"]["done"]
|
||||
tokens = await resp.json()
|
||||
|
||||
assert (
|
||||
await hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
)
|
||||
assert hass.auth.async_validate_access_token(tokens["access_token"]) is not None
|
||||
|
||||
# Onboarding refresh token and new refresh token
|
||||
user = await hass.auth.async_get_user(hass_admin_user.id)
|
||||
|
@ -368,7 +364,7 @@ async def test_onboarding_integration_missing_credential(
|
|||
assert await async_setup_component(hass, "onboarding", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token.credential = None
|
||||
|
||||
client = await hass_client()
|
||||
|
|
|
@ -134,7 +134,7 @@ async def test_auth_active_user_inactive(
|
|||
hass_access_token: str,
|
||||
) -> None:
|
||||
"""Test authenticating with a token."""
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token.user.is_active = False
|
||||
assert await async_setup_component(hass, "websocket_api", {})
|
||||
await hass.async_block_till_done()
|
||||
|
@ -216,8 +216,8 @@ async def test_auth_close_after_revoke(
|
|||
"""Test that a websocket is closed after the refresh token is revoked."""
|
||||
assert not websocket_client.closed
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
hass.auth.async_remove_refresh_token(refresh_token)
|
||||
|
||||
msg = await websocket_client.receive()
|
||||
assert msg.type == aiohttp.WSMsgType.CLOSE
|
||||
|
|
|
@ -775,7 +775,7 @@ async def test_call_service_context_with_user(
|
|||
msg = await ws.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
|
||||
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
|
|
Loading…
Reference in New Issue