Index auth token ids to avoid linear search (#116583)
* Index auth token ids to avoid linear search * async_remove_refresh_token * coveragepull/116891/head
parent
c8e6292cb7
commit
a57f4b8f42
|
@ -63,6 +63,7 @@ class AuthStore:
|
|||
self._store = Store[dict[str, list[dict[str, Any]]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._token_id_to_user_id: dict[str, str] = {}
|
||||
|
||||
async def async_get_groups(self) -> list[models.Group]:
|
||||
"""Retrieve all users."""
|
||||
|
@ -136,7 +137,10 @@ class AuthStore:
|
|||
|
||||
async def async_remove_user(self, user: models.User) -> None:
|
||||
"""Remove a user."""
|
||||
self._users.pop(user.id)
|
||||
user = self._users.pop(user.id)
|
||||
for refresh_token_id in user.refresh_tokens:
|
||||
del self._token_id_to_user_id[refresh_token_id]
|
||||
user.refresh_tokens.clear()
|
||||
self._async_schedule_save()
|
||||
|
||||
async def async_update_user(
|
||||
|
@ -219,7 +223,9 @@ class AuthStore:
|
|||
kwargs["client_icon"] = client_icon
|
||||
|
||||
refresh_token = models.RefreshToken(**kwargs)
|
||||
user.refresh_tokens[refresh_token.id] = refresh_token
|
||||
token_id = refresh_token.id
|
||||
user.refresh_tokens[token_id] = refresh_token
|
||||
self._token_id_to_user_id[token_id] = user.id
|
||||
|
||||
self._async_schedule_save()
|
||||
return refresh_token
|
||||
|
@ -227,19 +233,17 @@ class AuthStore:
|
|||
@callback
|
||||
def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None:
|
||||
"""Remove a refresh token."""
|
||||
for user in self._users.values():
|
||||
if user.refresh_tokens.pop(refresh_token.id, None):
|
||||
self._async_schedule_save()
|
||||
break
|
||||
refresh_token_id = refresh_token.id
|
||||
if user_id := self._token_id_to_user_id.get(refresh_token_id):
|
||||
del self._users[user_id].refresh_tokens[refresh_token_id]
|
||||
del self._token_id_to_user_id[refresh_token_id]
|
||||
self._async_schedule_save()
|
||||
|
||||
@callback
|
||||
def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None:
|
||||
"""Get refresh token by id."""
|
||||
for user in self._users.values():
|
||||
refresh_token = user.refresh_tokens.get(token_id)
|
||||
if refresh_token is not None:
|
||||
return refresh_token
|
||||
|
||||
if user_id := self._token_id_to_user_id.get(token_id):
|
||||
return self._users[user_id].refresh_tokens.get(token_id)
|
||||
return None
|
||||
|
||||
@callback
|
||||
|
@ -479,9 +483,18 @@ class AuthStore:
|
|||
|
||||
self._groups = groups
|
||||
self._users = users
|
||||
|
||||
self._build_token_id_to_user_id()
|
||||
self._async_schedule_save(INITIAL_LOAD_SAVE_DELAY)
|
||||
|
||||
@callback
|
||||
def _build_token_id_to_user_id(self) -> None:
|
||||
"""Build a map of token id to user id."""
|
||||
self._token_id_to_user_id = {
|
||||
token_id: user_id
|
||||
for user_id, user in self._users.items()
|
||||
for token_id in user.refresh_tokens
|
||||
}
|
||||
|
||||
@callback
|
||||
def _async_schedule_save(self, delay: float = DEFAULT_SAVE_DELAY) -> None:
|
||||
"""Save users."""
|
||||
|
@ -575,6 +588,7 @@ class AuthStore:
|
|||
read_only_group = _system_read_only_group()
|
||||
groups[read_only_group.id] = read_only_group
|
||||
self._groups = groups
|
||||
self._build_token_id_to_user_id()
|
||||
|
||||
|
||||
def _system_admin_group() -> models.Group:
|
||||
|
|
|
@ -305,3 +305,24 @@ async def test_loading_does_not_write_right_away(
|
|||
# Once for the task
|
||||
await hass.async_block_till_done()
|
||||
assert hass_storage[auth_store.STORAGE_KEY] != {}
|
||||
|
||||
|
||||
async def test_add_remove_user_affects_tokens(
|
||||
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||
) -> None:
|
||||
"""Test adding and removing a user removes the tokens."""
|
||||
store = auth_store.AuthStore(hass)
|
||||
await store.async_load()
|
||||
user = await store.async_create_user("Test User")
|
||||
assert user.name == "Test User"
|
||||
refresh_token = await store.async_create_refresh_token(
|
||||
user, "client_id", "access_token_expiration"
|
||||
)
|
||||
assert user.refresh_tokens == {refresh_token.id: refresh_token}
|
||||
assert await store.async_get_user(user.id) == user
|
||||
assert store.async_get_refresh_token(refresh_token.id) == refresh_token
|
||||
assert store.async_get_refresh_token_by_token(refresh_token.token) == refresh_token
|
||||
await store.async_remove_user(user)
|
||||
assert store.async_get_refresh_token(refresh_token.id) is None
|
||||
assert store.async_get_refresh_token_by_token(refresh_token.token) is None
|
||||
assert user.refresh_tokens == {}
|
||||
|
|
Loading…
Reference in New Issue