Allow delete_all_refresh_tokens to delete a specific token_type (#106119)

* Allow delete_all_refresh_tokens to delete a specific token_type

* add a test

* minor string change

* test updates

* more test updates

* more test updates

* fix tests

* do not delete current token

* Update tests/components/auth/test_init.py

* Update tests/components/auth/test_init.py

* Option to not delete the current token

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
pull/105955/head
karwosts 2024-01-29 11:09:23 -05:00 committed by GitHub
parent 075dab250e
commit f456e3a071
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 10 deletions

View File

@ -604,6 +604,8 @@ async def websocket_delete_refresh_token(
@websocket_api.websocket_command(
{
vol.Required("type"): "auth/delete_all_refresh_tokens",
vol.Optional("token_type"): cv.string,
vol.Optional("delete_current_token", default=True): bool,
}
)
@websocket_api.ws_require_user()
@ -614,6 +616,10 @@ async def websocket_delete_all_refresh_tokens(
"""Handle delete all refresh tokens request."""
current_refresh_token: RefreshToken
remove_failed = False
token_type = msg.get("token_type")
delete_current_token = msg.get("delete_current_token")
limit_token_types = token_type is not None
for token in list(connection.user.refresh_tokens.values()):
if token.id == connection.refresh_token_id:
# Skip the current refresh token as it has revoke_callback,
@ -621,6 +627,8 @@ async def websocket_delete_all_refresh_tokens(
# It will be removed after sending the result.
current_refresh_token = token
continue
if limit_token_types and token_type != token.token_type:
continue
try:
hass.auth.async_remove_refresh_token(token)
except Exception as err: # pylint: disable=broad-except
@ -637,6 +645,9 @@ async def websocket_delete_all_refresh_tokens(
else:
connection.send_result(msg["id"], {})
if delete_current_token and (
not limit_token_types or current_refresh_token.token_type == token_type
):
# This will close the connection so we need to send the result first.
hass.loop.call_soon(hass.auth.async_remove_refresh_token, current_refresh_token)

View File

@ -8,7 +8,11 @@ from freezegun.api import FrozenDateTimeFactory
import pytest
from homeassistant.auth import InvalidAuthError
from homeassistant.auth.models import Credentials
from homeassistant.auth.models import (
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
TOKEN_TYPE_NORMAL,
Credentials,
)
from homeassistant.components import auth
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
@ -567,22 +571,50 @@ async def test_ws_delete_all_refresh_tokens_error(
assert refresh_token is None
@pytest.mark.parametrize(
(
"delete_token_type",
"delete_current_token",
"expected_remaining_normal_tokens",
"expected_remaining_long_lived_tokens",
),
[
({}, {}, 0, 0),
({"token_type": TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN}, {}, 3, 0),
({"token_type": TOKEN_TYPE_NORMAL}, {}, 0, 1),
({"token_type": TOKEN_TYPE_NORMAL}, {"delete_current_token": False}, 1, 1),
],
)
async def test_ws_delete_all_refresh_tokens(
hass: HomeAssistant,
hass_admin_user: MockUser,
hass_admin_credential: Credentials,
hass_ws_client: WebSocketGenerator,
hass_access_token: str,
delete_token_type: dict[str:str],
delete_current_token: dict[str:bool],
expected_remaining_normal_tokens: int,
expected_remaining_long_lived_tokens: int,
) -> None:
"""Test deleting all refresh tokens."""
"""Test deleting all or some refresh tokens."""
assert await async_setup_component(hass, "auth", {"http": {}})
# one token already exists
await hass.auth.async_create_refresh_token(
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
)
# create a long lived token
await hass.auth.async_create_refresh_token(
hass_admin_user, CLIENT_ID + "_1", credential=hass_admin_credential
hass_admin_user,
f"{CLIENT_ID}_LL",
client_name="client_ll",
credential=hass_admin_credential,
token_type=TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
)
await hass.auth.async_create_refresh_token(
hass_admin_user, f"{CLIENT_ID}_1", credential=hass_admin_credential
)
ws_client = await hass_ws_client(hass, hass_access_token)
@ -592,20 +624,35 @@ async def test_ws_delete_all_refresh_tokens(
result = await ws_client.receive_json()
assert result["success"], result
tokens = result["result"]
await ws_client.send_json(
{
"id": 6,
"type": "auth/delete_all_refresh_tokens",
**delete_token_type,
**delete_current_token,
}
)
result = await ws_client.receive_json()
assert result, result["success"]
for token in tokens:
refresh_token = hass.auth.async_get_refresh_token(token["id"])
assert refresh_token is None
# We need to enumerate the user since we may remove the token
# that is used to authenticate the user which will prevent the websocket
# connection from working
remaining_tokens_by_type: dict[str, int] = {
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN: 0,
TOKEN_TYPE_NORMAL: 0,
}
for refresh_token in hass_admin_user.refresh_tokens.values():
remaining_tokens_by_type[refresh_token.token_type] += 1
assert (
remaining_tokens_by_type[TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN]
== expected_remaining_long_lived_tokens
)
assert (
remaining_tokens_by_type[TOKEN_TYPE_NORMAL] == expected_remaining_normal_tokens
)
async def test_ws_sign_path(