Add ws endpoint to remove expiration date from refresh tokens (#117546)

Co-authored-by: Erik Montnemery <erik@montnemery.com>
pull/118353/head
Robert Resch 2024-05-29 09:09:59 +02:00 committed by GitHub
parent 7e62061b9a
commit e087abe802
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 235 additions and 73 deletions

View File

@ -516,6 +516,13 @@ class AuthManager:
for revoke_callback in callbacks:
revoke_callback()
@callback
def async_set_expiry(
self, refresh_token: models.RefreshToken, *, enable_expiry: bool
) -> None:
"""Enable or disable expiry of a refresh token."""
self._store.async_set_expiry(refresh_token, enable_expiry=enable_expiry)
@callback
def _async_remove_expired_refresh_tokens(self, _: datetime | None = None) -> None:
"""Remove expired refresh tokens."""

View File

@ -6,7 +6,6 @@ from datetime import timedelta
import hmac
import itertools
from logging import getLogger
import time
from typing import Any
from homeassistant.core import HomeAssistant, callback
@ -282,6 +281,21 @@ class AuthStore:
)
self._async_schedule_save()
@callback
def async_set_expiry(
self, refresh_token: models.RefreshToken, *, enable_expiry: bool
) -> None:
"""Enable or disable expiry of a refresh token."""
if enable_expiry:
if refresh_token.expire_at is None:
refresh_token.expire_at = (
refresh_token.last_used_at or dt_util.utcnow()
).timestamp() + REFRESH_TOKEN_EXPIRATION
self._async_schedule_save()
else:
refresh_token.expire_at = None
self._async_schedule_save()
async def async_load(self) -> None: # noqa: C901
"""Load the users."""
if self._loaded:
@ -295,8 +309,6 @@ class AuthStore:
perm_lookup = PermissionLookup(ent_reg, dev_reg)
self._perm_lookup = perm_lookup
now_ts = time.time()
if data is None or not isinstance(data, dict):
self._set_defaults()
return
@ -450,14 +462,6 @@ class AuthStore:
else:
last_used_at = None
if (
expire_at := rt_dict.get("expire_at")
) is None and token_type == models.TOKEN_TYPE_NORMAL:
if last_used_at:
expire_at = last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
else:
expire_at = now_ts + REFRESH_TOKEN_EXPIRATION
token = models.RefreshToken(
id=rt_dict["id"],
user=users[rt_dict["user_id"]],
@ -474,7 +478,7 @@ class AuthStore:
jwt_key=rt_dict["jwt_key"],
last_used_at=last_used_at,
last_used_ip=rt_dict.get("last_used_ip"),
expire_at=expire_at,
expire_at=rt_dict.get("expire_at"),
version=rt_dict.get("version"),
)
if "credential_id" in rt_dict:

View File

@ -197,6 +197,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
websocket_api.async_register_command(hass, websocket_delete_refresh_token)
websocket_api.async_register_command(hass, websocket_delete_all_refresh_tokens)
websocket_api.async_register_command(hass, websocket_sign_path)
websocket_api.async_register_command(hass, websocket_refresh_token_set_expiry)
login_flow.async_setup(hass, store_result)
mfa_setup_flow.async_setup(hass)
@ -565,18 +566,23 @@ def websocket_refresh_tokens(
else:
auth_provider_type = None
expire_at = None
if refresh.expire_at:
expire_at = dt_util.utc_from_timestamp(refresh.expire_at)
tokens.append(
{
"id": refresh.id,
"auth_provider_type": auth_provider_type,
"client_icon": refresh.client_icon,
"client_id": refresh.client_id,
"client_name": refresh.client_name,
"client_icon": refresh.client_icon,
"type": refresh.token_type,
"created_at": refresh.created_at,
"expire_at": expire_at,
"id": refresh.id,
"is_current": refresh.id == current_id,
"last_used_at": refresh.last_used_at,
"last_used_ip": refresh.last_used_ip,
"auth_provider_type": auth_provider_type,
"type": refresh.token_type,
}
)
@ -702,3 +708,26 @@ def websocket_sign_path(
},
)
)
@callback
@websocket_api.websocket_command(
{
vol.Required("type"): "auth/refresh_token_set_expiry",
vol.Required("refresh_token_id"): str,
vol.Required("enable_expiry"): bool,
}
)
@websocket_api.ws_require_user()
def websocket_refresh_token_set_expiry(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle a set expiry of a refresh token request."""
refresh_token = connection.user.refresh_tokens.get(msg["refresh_token_id"])
if refresh_token is None:
connection.send_error(msg["id"], "invalid_token_id", "Received invalid token")
return
hass.auth.async_set_expiry(refresh_token, enable_expiry=msg["enable_expiry"])
connection.send_result(msg["id"], {})

View File

@ -1,17 +1,14 @@
"""Tests for the auth store."""
import asyncio
from datetime import timedelta
from typing import Any
from unittest.mock import patch
from freezegun import freeze_time
from freezegun.api import FrozenDateTimeFactory
import pytest
from homeassistant.auth import auth_store
from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util
MOCK_STORAGE_DATA = {
"version": 1,
@ -220,68 +217,64 @@ async def test_loading_only_once(hass: HomeAssistant) -> None:
assert results[0] == results[1]
async def test_add_expire_at_property(
async def test_dont_change_expire_at_on_load(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test we correctly add expired_at property if not existing."""
now = dt_util.utcnow()
with freeze_time(now):
hass_storage[auth_store.STORAGE_KEY] = {
"version": 1,
"data": {
"credentials": [],
"users": [
{
"id": "user-id",
"is_active": True,
"is_owner": True,
"name": "Paulus",
"system_generated": False,
},
{
"id": "system-id",
"is_active": True,
"is_owner": True,
"name": "Hass.io",
"system_generated": True,
},
],
"refresh_tokens": [
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id",
"jwt_key": "some-key",
"last_used_at": str(now - timedelta(days=10)),
"token": "some-token",
"user_id": "user-id",
"version": "1.2.3",
},
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id2",
"jwt_key": "some-key2",
"token": "some-token",
"user_id": "user-id",
},
],
},
}
"""Test we correctly don't modify expired_at store load."""
hass_storage[auth_store.STORAGE_KEY] = {
"version": 1,
"data": {
"credentials": [],
"users": [
{
"id": "user-id",
"is_active": True,
"is_owner": True,
"name": "Paulus",
"system_generated": False,
},
{
"id": "system-id",
"is_active": True,
"is_owner": True,
"name": "Hass.io",
"system_generated": True,
},
],
"refresh_tokens": [
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id",
"jwt_key": "some-key",
"token": "some-token",
"user_id": "user-id",
"version": "1.2.3",
},
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id2",
"jwt_key": "some-key2",
"token": "some-token",
"user_id": "user-id",
"expire_at": 1724133771.079745,
},
],
},
}
store = auth_store.AuthStore(hass)
await store.async_load()
store = auth_store.AuthStore(hass)
await store.async_load()
users = await store.async_get_users()
assert len(users[0].refresh_tokens) == 2
token1, token2 = users[0].refresh_tokens.values()
assert token1.expire_at
assert token1.expire_at == now.timestamp() + timedelta(days=80).total_seconds()
assert token2.expire_at
assert token2.expire_at == now.timestamp() + timedelta(days=90).total_seconds()
assert not token1.expire_at
assert token2.expire_at == 1724133771.079745
async def test_loading_does_not_write_right_away(
@ -326,3 +319,63 @@ async def test_add_remove_user_affects_tokens(
assert store.async_get_refresh_token(refresh_token.id) is None
assert store.async_get_refresh_token_by_token(refresh_token.token) is None
assert user.refresh_tokens == {}
async def test_set_expiry_date(
hass: HomeAssistant, hass_storage: dict[str, Any], freezer: FrozenDateTimeFactory
) -> None:
"""Test set expiry date of a refresh token."""
hass_storage[auth_store.STORAGE_KEY] = {
"version": 1,
"data": {
"credentials": [],
"users": [
{
"id": "user-id",
"is_active": True,
"is_owner": True,
"name": "Paulus",
"system_generated": False,
},
],
"refresh_tokens": [
{
"access_token_expiration": 1800.0,
"client_id": "http://localhost:8123/",
"created_at": "2018-10-03T13:43:19.774637+00:00",
"id": "user-token-id",
"jwt_key": "some-key",
"token": "some-token",
"user_id": "user-id",
"expire_at": 1724133771.079745,
},
],
},
}
store = auth_store.AuthStore(hass)
await store.async_load()
users = await store.async_get_users()
assert len(users[0].refresh_tokens) == 1
(token,) = users[0].refresh_tokens.values()
assert token.expire_at == 1724133771.079745
store.async_set_expiry(token, enable_expiry=False)
assert token.expire_at is None
freezer.tick(auth_store.DEFAULT_SAVE_DELAY * 2)
# Once for scheduling the task
await hass.async_block_till_done()
# Once for the task
await hass.async_block_till_done()
# verify token is saved without expire_at
assert (
hass_storage[auth_store.STORAGE_KEY]["data"]["refresh_tokens"][0]["expire_at"]
is None
)
store.async_set_expiry(token, enable_expiry=True)
assert token.expire_at is not None

View File

@ -690,3 +690,72 @@ async def test_ws_sign_path(
hass, path, expires = mock_sign.mock_calls[0][1]
assert path == "/api/hello"
assert expires.total_seconds() == 20
async def test_ws_refresh_token_set_expiry(
hass: HomeAssistant,
hass_admin_user: MockUser,
hass_admin_credential: Credentials,
hass_ws_client: WebSocketGenerator,
hass_access_token: str,
) -> None:
"""Test setting expiry of a refresh token."""
assert await async_setup_component(hass, "auth", {"http": {}})
refresh_token = await hass.auth.async_create_refresh_token(
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
)
assert refresh_token.expire_at is not None
ws_client = await hass_ws_client(hass, hass_access_token)
await ws_client.send_json_auto_id(
{
"type": "auth/refresh_token_set_expiry",
"refresh_token_id": refresh_token.id,
"enable_expiry": False,
}
)
result = await ws_client.receive_json()
assert result["success"], result
refresh_token = hass.auth.async_get_refresh_token(refresh_token.id)
assert refresh_token.expire_at is None
await ws_client.send_json_auto_id(
{
"type": "auth/refresh_token_set_expiry",
"refresh_token_id": refresh_token.id,
"enable_expiry": True,
}
)
result = await ws_client.receive_json()
assert result["success"], result
refresh_token = hass.auth.async_get_refresh_token(refresh_token.id)
assert refresh_token.expire_at is not None
async def test_ws_refresh_token_set_expiry_error(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
hass_access_token: str,
) -> None:
"""Test setting expiry of a invalid refresh token returns error."""
assert await async_setup_component(hass, "auth", {"http": {}})
ws_client = await hass_ws_client(hass, hass_access_token)
await ws_client.send_json_auto_id(
{
"type": "auth/refresh_token_set_expiry",
"refresh_token_id": "invalid",
"enable_expiry": False,
}
)
result = await ws_client.receive_json()
assert result, result["success"] is False
assert result["error"] == {
"code": "invalid_token_id",
"message": "Received invalid token",
}