Allow delete_all_refresh_tokens to delete a specific token_type (#106119)
* Allow delete_all_refresh_tokens to delete a specific token_type * add a test * minor string change * test updates * more test updates * more test updates * fix tests * do not delete current token * Update tests/components/auth/test_init.py * Update tests/components/auth/test_init.py * Option to not delete the current token --------- Co-authored-by: J. Nick Koston <nick@koston.org>pull/105955/head
parent
075dab250e
commit
f456e3a071
|
@ -604,6 +604,8 @@ async def websocket_delete_refresh_token(
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
{
|
{
|
||||||
vol.Required("type"): "auth/delete_all_refresh_tokens",
|
vol.Required("type"): "auth/delete_all_refresh_tokens",
|
||||||
|
vol.Optional("token_type"): cv.string,
|
||||||
|
vol.Optional("delete_current_token", default=True): bool,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@websocket_api.ws_require_user()
|
@websocket_api.ws_require_user()
|
||||||
|
@ -614,6 +616,10 @@ async def websocket_delete_all_refresh_tokens(
|
||||||
"""Handle delete all refresh tokens request."""
|
"""Handle delete all refresh tokens request."""
|
||||||
current_refresh_token: RefreshToken
|
current_refresh_token: RefreshToken
|
||||||
remove_failed = False
|
remove_failed = False
|
||||||
|
token_type = msg.get("token_type")
|
||||||
|
delete_current_token = msg.get("delete_current_token")
|
||||||
|
limit_token_types = token_type is not None
|
||||||
|
|
||||||
for token in list(connection.user.refresh_tokens.values()):
|
for token in list(connection.user.refresh_tokens.values()):
|
||||||
if token.id == connection.refresh_token_id:
|
if token.id == connection.refresh_token_id:
|
||||||
# Skip the current refresh token as it has revoke_callback,
|
# Skip the current refresh token as it has revoke_callback,
|
||||||
|
@ -621,6 +627,8 @@ async def websocket_delete_all_refresh_tokens(
|
||||||
# It will be removed after sending the result.
|
# It will be removed after sending the result.
|
||||||
current_refresh_token = token
|
current_refresh_token = token
|
||||||
continue
|
continue
|
||||||
|
if limit_token_types and token_type != token.token_type:
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
hass.auth.async_remove_refresh_token(token)
|
hass.auth.async_remove_refresh_token(token)
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
@ -637,8 +645,11 @@ async def websocket_delete_all_refresh_tokens(
|
||||||
else:
|
else:
|
||||||
connection.send_result(msg["id"], {})
|
connection.send_result(msg["id"], {})
|
||||||
|
|
||||||
# This will close the connection so we need to send the result first.
|
if delete_current_token and (
|
||||||
hass.loop.call_soon(hass.auth.async_remove_refresh_token, current_refresh_token)
|
not limit_token_types or current_refresh_token.token_type == token_type
|
||||||
|
):
|
||||||
|
# This will close the connection so we need to send the result first.
|
||||||
|
hass.loop.call_soon(hass.auth.async_remove_refresh_token, current_refresh_token)
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.websocket_command(
|
@websocket_api.websocket_command(
|
||||||
|
|
|
@ -8,7 +8,11 @@ from freezegun.api import FrozenDateTimeFactory
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.auth import InvalidAuthError
|
from homeassistant.auth import InvalidAuthError
|
||||||
from homeassistant.auth.models import Credentials
|
from homeassistant.auth.models import (
|
||||||
|
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||||
|
TOKEN_TYPE_NORMAL,
|
||||||
|
Credentials,
|
||||||
|
)
|
||||||
from homeassistant.components import auth
|
from homeassistant.components import auth
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
@ -567,22 +571,50 @@ async def test_ws_delete_all_refresh_tokens_error(
|
||||||
assert refresh_token is None
|
assert refresh_token is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
(
|
||||||
|
"delete_token_type",
|
||||||
|
"delete_current_token",
|
||||||
|
"expected_remaining_normal_tokens",
|
||||||
|
"expected_remaining_long_lived_tokens",
|
||||||
|
),
|
||||||
|
[
|
||||||
|
({}, {}, 0, 0),
|
||||||
|
({"token_type": TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN}, {}, 3, 0),
|
||||||
|
({"token_type": TOKEN_TYPE_NORMAL}, {}, 0, 1),
|
||||||
|
({"token_type": TOKEN_TYPE_NORMAL}, {"delete_current_token": False}, 1, 1),
|
||||||
|
],
|
||||||
|
)
|
||||||
async def test_ws_delete_all_refresh_tokens(
|
async def test_ws_delete_all_refresh_tokens(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_admin_user: MockUser,
|
hass_admin_user: MockUser,
|
||||||
hass_admin_credential: Credentials,
|
hass_admin_credential: Credentials,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
hass_access_token: str,
|
hass_access_token: str,
|
||||||
|
delete_token_type: dict[str:str],
|
||||||
|
delete_current_token: dict[str:bool],
|
||||||
|
expected_remaining_normal_tokens: int,
|
||||||
|
expected_remaining_long_lived_tokens: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test deleting all refresh tokens."""
|
"""Test deleting all or some refresh tokens."""
|
||||||
assert await async_setup_component(hass, "auth", {"http": {}})
|
assert await async_setup_component(hass, "auth", {"http": {}})
|
||||||
|
|
||||||
# one token already exists
|
# one token already exists
|
||||||
await hass.auth.async_create_refresh_token(
|
await hass.auth.async_create_refresh_token(
|
||||||
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
|
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# create a long lived token
|
||||||
await hass.auth.async_create_refresh_token(
|
await hass.auth.async_create_refresh_token(
|
||||||
hass_admin_user, CLIENT_ID + "_1", credential=hass_admin_credential
|
hass_admin_user,
|
||||||
|
f"{CLIENT_ID}_LL",
|
||||||
|
client_name="client_ll",
|
||||||
|
credential=hass_admin_credential,
|
||||||
|
token_type=TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.auth.async_create_refresh_token(
|
||||||
|
hass_admin_user, f"{CLIENT_ID}_1", credential=hass_admin_credential
|
||||||
)
|
)
|
||||||
|
|
||||||
ws_client = await hass_ws_client(hass, hass_access_token)
|
ws_client = await hass_ws_client(hass, hass_access_token)
|
||||||
|
@ -592,20 +624,35 @@ async def test_ws_delete_all_refresh_tokens(
|
||||||
result = await ws_client.receive_json()
|
result = await ws_client.receive_json()
|
||||||
assert result["success"], result
|
assert result["success"], result
|
||||||
|
|
||||||
tokens = result["result"]
|
|
||||||
|
|
||||||
await ws_client.send_json(
|
await ws_client.send_json(
|
||||||
{
|
{
|
||||||
"id": 6,
|
"id": 6,
|
||||||
"type": "auth/delete_all_refresh_tokens",
|
"type": "auth/delete_all_refresh_tokens",
|
||||||
|
**delete_token_type,
|
||||||
|
**delete_current_token,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await ws_client.receive_json()
|
result = await ws_client.receive_json()
|
||||||
assert result, result["success"]
|
assert result, result["success"]
|
||||||
for token in tokens:
|
|
||||||
refresh_token = hass.auth.async_get_refresh_token(token["id"])
|
# We need to enumerate the user since we may remove the token
|
||||||
assert refresh_token is None
|
# that is used to authenticate the user which will prevent the websocket
|
||||||
|
# connection from working
|
||||||
|
remaining_tokens_by_type: dict[str, int] = {
|
||||||
|
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN: 0,
|
||||||
|
TOKEN_TYPE_NORMAL: 0,
|
||||||
|
}
|
||||||
|
for refresh_token in hass_admin_user.refresh_tokens.values():
|
||||||
|
remaining_tokens_by_type[refresh_token.token_type] += 1
|
||||||
|
|
||||||
|
assert (
|
||||||
|
remaining_tokens_by_type[TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN]
|
||||||
|
== expected_remaining_long_lived_tokens
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
remaining_tokens_by_type[TOKEN_TYPE_NORMAL] == expected_remaining_normal_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_ws_sign_path(
|
async def test_ws_sign_path(
|
||||||
|
|
Loading…
Reference in New Issue