Add ws endpoint to remove expiration date from refresh tokens (#117546)
Co-authored-by: Erik Montnemery <erik@montnemery.com>pull/118353/head
parent
7e62061b9a
commit
e087abe802
|
@ -516,6 +516,13 @@ class AuthManager:
|
|||
for revoke_callback in callbacks:
|
||||
revoke_callback()
|
||||
|
||||
@callback
|
||||
def async_set_expiry(
|
||||
self, refresh_token: models.RefreshToken, *, enable_expiry: bool
|
||||
) -> None:
|
||||
"""Enable or disable expiry of a refresh token."""
|
||||
self._store.async_set_expiry(refresh_token, enable_expiry=enable_expiry)
|
||||
|
||||
@callback
|
||||
def _async_remove_expired_refresh_tokens(self, _: datetime | None = None) -> None:
|
||||
"""Remove expired refresh tokens."""
|
||||
|
|
|
@ -6,7 +6,6 @@ from datetime import timedelta
|
|||
import hmac
|
||||
import itertools
|
||||
from logging import getLogger
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
@ -282,6 +281,21 @@ class AuthStore:
|
|||
)
|
||||
self._async_schedule_save()
|
||||
|
||||
@callback
|
||||
def async_set_expiry(
|
||||
self, refresh_token: models.RefreshToken, *, enable_expiry: bool
|
||||
) -> None:
|
||||
"""Enable or disable expiry of a refresh token."""
|
||||
if enable_expiry:
|
||||
if refresh_token.expire_at is None:
|
||||
refresh_token.expire_at = (
|
||||
refresh_token.last_used_at or dt_util.utcnow()
|
||||
).timestamp() + REFRESH_TOKEN_EXPIRATION
|
||||
self._async_schedule_save()
|
||||
else:
|
||||
refresh_token.expire_at = None
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_load(self) -> None: # noqa: C901
|
||||
"""Load the users."""
|
||||
if self._loaded:
|
||||
|
@ -295,8 +309,6 @@ class AuthStore:
|
|||
perm_lookup = PermissionLookup(ent_reg, dev_reg)
|
||||
self._perm_lookup = perm_lookup
|
||||
|
||||
now_ts = time.time()
|
||||
|
||||
if data is None or not isinstance(data, dict):
|
||||
self._set_defaults()
|
||||
return
|
||||
|
@ -450,14 +462,6 @@ class AuthStore:
|
|||
else:
|
||||
last_used_at = None
|
||||
|
||||
if (
|
||||
expire_at := rt_dict.get("expire_at")
|
||||
) is None and token_type == models.TOKEN_TYPE_NORMAL:
|
||||
if last_used_at:
|
||||
expire_at = last_used_at.timestamp() + REFRESH_TOKEN_EXPIRATION
|
||||
else:
|
||||
expire_at = now_ts + REFRESH_TOKEN_EXPIRATION
|
||||
|
||||
token = models.RefreshToken(
|
||||
id=rt_dict["id"],
|
||||
user=users[rt_dict["user_id"]],
|
||||
|
@ -474,7 +478,7 @@ class AuthStore:
|
|||
jwt_key=rt_dict["jwt_key"],
|
||||
last_used_at=last_used_at,
|
||||
last_used_ip=rt_dict.get("last_used_ip"),
|
||||
expire_at=expire_at,
|
||||
expire_at=rt_dict.get("expire_at"),
|
||||
version=rt_dict.get("version"),
|
||||
)
|
||||
if "credential_id" in rt_dict:
|
||||
|
|
|
@ -197,6 +197,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
websocket_api.async_register_command(hass, websocket_delete_refresh_token)
|
||||
websocket_api.async_register_command(hass, websocket_delete_all_refresh_tokens)
|
||||
websocket_api.async_register_command(hass, websocket_sign_path)
|
||||
websocket_api.async_register_command(hass, websocket_refresh_token_set_expiry)
|
||||
|
||||
login_flow.async_setup(hass, store_result)
|
||||
mfa_setup_flow.async_setup(hass)
|
||||
|
@ -565,18 +566,23 @@ def websocket_refresh_tokens(
|
|||
else:
|
||||
auth_provider_type = None
|
||||
|
||||
expire_at = None
|
||||
if refresh.expire_at:
|
||||
expire_at = dt_util.utc_from_timestamp(refresh.expire_at)
|
||||
|
||||
tokens.append(
|
||||
{
|
||||
"id": refresh.id,
|
||||
"auth_provider_type": auth_provider_type,
|
||||
"client_icon": refresh.client_icon,
|
||||
"client_id": refresh.client_id,
|
||||
"client_name": refresh.client_name,
|
||||
"client_icon": refresh.client_icon,
|
||||
"type": refresh.token_type,
|
||||
"created_at": refresh.created_at,
|
||||
"expire_at": expire_at,
|
||||
"id": refresh.id,
|
||||
"is_current": refresh.id == current_id,
|
||||
"last_used_at": refresh.last_used_at,
|
||||
"last_used_ip": refresh.last_used_ip,
|
||||
"auth_provider_type": auth_provider_type,
|
||||
"type": refresh.token_type,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -702,3 +708,26 @@ def websocket_sign_path(
|
|||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "auth/refresh_token_set_expiry",
|
||||
vol.Required("refresh_token_id"): str,
|
||||
vol.Required("enable_expiry"): bool,
|
||||
}
|
||||
)
|
||||
@websocket_api.ws_require_user()
|
||||
def websocket_refresh_token_set_expiry(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle a set expiry of a refresh token request."""
|
||||
refresh_token = connection.user.refresh_tokens.get(msg["refresh_token_id"])
|
||||
|
||||
if refresh_token is None:
|
||||
connection.send_error(msg["id"], "invalid_token_id", "Received invalid token")
|
||||
return
|
||||
|
||||
hass.auth.async_set_expiry(refresh_token, enable_expiry=msg["enable_expiry"])
|
||||
connection.send_result(msg["id"], {})
|
||||
|
|
|
@ -1,17 +1,14 @@
|
|||
"""Tests for the auth store."""
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
import pytest
|
||||
|
||||
from homeassistant.auth import auth_store
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
MOCK_STORAGE_DATA = {
|
||||
"version": 1,
|
||||
|
@ -220,68 +217,64 @@ async def test_loading_only_once(hass: HomeAssistant) -> None:
|
|||
assert results[0] == results[1]
|
||||
|
||||
|
||||
async def test_add_expire_at_property(
|
||||
async def test_dont_change_expire_at_on_load(
|
||||
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||
) -> None:
|
||||
"""Test we correctly add expired_at property if not existing."""
|
||||
now = dt_util.utcnow()
|
||||
with freeze_time(now):
|
||||
hass_storage[auth_store.STORAGE_KEY] = {
|
||||
"version": 1,
|
||||
"data": {
|
||||
"credentials": [],
|
||||
"users": [
|
||||
{
|
||||
"id": "user-id",
|
||||
"is_active": True,
|
||||
"is_owner": True,
|
||||
"name": "Paulus",
|
||||
"system_generated": False,
|
||||
},
|
||||
{
|
||||
"id": "system-id",
|
||||
"is_active": True,
|
||||
"is_owner": True,
|
||||
"name": "Hass.io",
|
||||
"system_generated": True,
|
||||
},
|
||||
],
|
||||
"refresh_tokens": [
|
||||
{
|
||||
"access_token_expiration": 1800.0,
|
||||
"client_id": "http://localhost:8123/",
|
||||
"created_at": "2018-10-03T13:43:19.774637+00:00",
|
||||
"id": "user-token-id",
|
||||
"jwt_key": "some-key",
|
||||
"last_used_at": str(now - timedelta(days=10)),
|
||||
"token": "some-token",
|
||||
"user_id": "user-id",
|
||||
"version": "1.2.3",
|
||||
},
|
||||
{
|
||||
"access_token_expiration": 1800.0,
|
||||
"client_id": "http://localhost:8123/",
|
||||
"created_at": "2018-10-03T13:43:19.774637+00:00",
|
||||
"id": "user-token-id2",
|
||||
"jwt_key": "some-key2",
|
||||
"token": "some-token",
|
||||
"user_id": "user-id",
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
"""Test we correctly don't modify expired_at store load."""
|
||||
hass_storage[auth_store.STORAGE_KEY] = {
|
||||
"version": 1,
|
||||
"data": {
|
||||
"credentials": [],
|
||||
"users": [
|
||||
{
|
||||
"id": "user-id",
|
||||
"is_active": True,
|
||||
"is_owner": True,
|
||||
"name": "Paulus",
|
||||
"system_generated": False,
|
||||
},
|
||||
{
|
||||
"id": "system-id",
|
||||
"is_active": True,
|
||||
"is_owner": True,
|
||||
"name": "Hass.io",
|
||||
"system_generated": True,
|
||||
},
|
||||
],
|
||||
"refresh_tokens": [
|
||||
{
|
||||
"access_token_expiration": 1800.0,
|
||||
"client_id": "http://localhost:8123/",
|
||||
"created_at": "2018-10-03T13:43:19.774637+00:00",
|
||||
"id": "user-token-id",
|
||||
"jwt_key": "some-key",
|
||||
"token": "some-token",
|
||||
"user_id": "user-id",
|
||||
"version": "1.2.3",
|
||||
},
|
||||
{
|
||||
"access_token_expiration": 1800.0,
|
||||
"client_id": "http://localhost:8123/",
|
||||
"created_at": "2018-10-03T13:43:19.774637+00:00",
|
||||
"id": "user-token-id2",
|
||||
"jwt_key": "some-key2",
|
||||
"token": "some-token",
|
||||
"user_id": "user-id",
|
||||
"expire_at": 1724133771.079745,
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
|
||||
users = await store.async_get_users()
|
||||
|
||||
assert len(users[0].refresh_tokens) == 2
|
||||
token1, token2 = users[0].refresh_tokens.values()
|
||||
assert token1.expire_at
|
||||
assert token1.expire_at == now.timestamp() + timedelta(days=80).total_seconds()
|
||||
assert token2.expire_at
|
||||
assert token2.expire_at == now.timestamp() + timedelta(days=90).total_seconds()
|
||||
assert not token1.expire_at
|
||||
assert token2.expire_at == 1724133771.079745
|
||||
|
||||
|
||||
async def test_loading_does_not_write_right_away(
|
||||
|
@ -326,3 +319,63 @@ async def test_add_remove_user_affects_tokens(
|
|||
assert store.async_get_refresh_token(refresh_token.id) is None
|
||||
assert store.async_get_refresh_token_by_token(refresh_token.token) is None
|
||||
assert user.refresh_tokens == {}
|
||||
|
||||
|
||||
async def test_set_expiry_date(
|
||||
hass: HomeAssistant, hass_storage: dict[str, Any], freezer: FrozenDateTimeFactory
|
||||
) -> None:
|
||||
"""Test set expiry date of a refresh token."""
|
||||
hass_storage[auth_store.STORAGE_KEY] = {
|
||||
"version": 1,
|
||||
"data": {
|
||||
"credentials": [],
|
||||
"users": [
|
||||
{
|
||||
"id": "user-id",
|
||||
"is_active": True,
|
||||
"is_owner": True,
|
||||
"name": "Paulus",
|
||||
"system_generated": False,
|
||||
},
|
||||
],
|
||||
"refresh_tokens": [
|
||||
{
|
||||
"access_token_expiration": 1800.0,
|
||||
"client_id": "http://localhost:8123/",
|
||||
"created_at": "2018-10-03T13:43:19.774637+00:00",
|
||||
"id": "user-token-id",
|
||||
"jwt_key": "some-key",
|
||||
"token": "some-token",
|
||||
"user_id": "user-id",
|
||||
"expire_at": 1724133771.079745,
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
|
||||
users = await store.async_get_users()
|
||||
|
||||
assert len(users[0].refresh_tokens) == 1
|
||||
(token,) = users[0].refresh_tokens.values()
|
||||
assert token.expire_at == 1724133771.079745
|
||||
|
||||
store.async_set_expiry(token, enable_expiry=False)
|
||||
assert token.expire_at is None
|
||||
|
||||
freezer.tick(auth_store.DEFAULT_SAVE_DELAY * 2)
|
||||
# Once for scheduling the task
|
||||
await hass.async_block_till_done()
|
||||
# Once for the task
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# verify token is saved without expire_at
|
||||
assert (
|
||||
hass_storage[auth_store.STORAGE_KEY]["data"]["refresh_tokens"][0]["expire_at"]
|
||||
is None
|
||||
)
|
||||
|
||||
store.async_set_expiry(token, enable_expiry=True)
|
||||
assert token.expire_at is not None
|
||||
|
|
|
@ -690,3 +690,72 @@ async def test_ws_sign_path(
|
|||
hass, path, expires = mock_sign.mock_calls[0][1]
|
||||
assert path == "/api/hello"
|
||||
assert expires.total_seconds() == 20
|
||||
|
||||
|
||||
async def test_ws_refresh_token_set_expiry(
|
||||
hass: HomeAssistant,
|
||||
hass_admin_user: MockUser,
|
||||
hass_admin_credential: Credentials,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
hass_access_token: str,
|
||||
) -> None:
|
||||
"""Test setting expiry of a refresh token."""
|
||||
assert await async_setup_component(hass, "auth", {"http": {}})
|
||||
|
||||
refresh_token = await hass.auth.async_create_refresh_token(
|
||||
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
|
||||
)
|
||||
assert refresh_token.expire_at is not None
|
||||
ws_client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "auth/refresh_token_set_expiry",
|
||||
"refresh_token_id": refresh_token.id,
|
||||
"enable_expiry": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = await ws_client.receive_json()
|
||||
assert result["success"], result
|
||||
refresh_token = hass.auth.async_get_refresh_token(refresh_token.id)
|
||||
assert refresh_token.expire_at is None
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "auth/refresh_token_set_expiry",
|
||||
"refresh_token_id": refresh_token.id,
|
||||
"enable_expiry": True,
|
||||
}
|
||||
)
|
||||
|
||||
result = await ws_client.receive_json()
|
||||
assert result["success"], result
|
||||
refresh_token = hass.auth.async_get_refresh_token(refresh_token.id)
|
||||
assert refresh_token.expire_at is not None
|
||||
|
||||
|
||||
async def test_ws_refresh_token_set_expiry_error(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
hass_access_token: str,
|
||||
) -> None:
|
||||
"""Test setting expiry of a invalid refresh token returns error."""
|
||||
assert await async_setup_component(hass, "auth", {"http": {}})
|
||||
|
||||
ws_client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "auth/refresh_token_set_expiry",
|
||||
"refresh_token_id": "invalid",
|
||||
"enable_expiry": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = await ws_client.receive_json()
|
||||
assert result, result["success"] is False
|
||||
assert result["error"] == {
|
||||
"code": "invalid_token_id",
|
||||
"message": "Received invalid token",
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue