Index auth token ids to avoid linear search (#116583)

* Index auth token ids to avoid linear search

* async_remove_refresh_token

* coverage
pull/116891/head
J. Nick Koston 2024-05-05 15:47:26 -05:00 committed by GitHub
parent c8e6292cb7
commit a57f4b8f42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 12 deletions

View File

@ -63,6 +63,7 @@ class AuthStore:
self._store = Store[dict[str, list[dict[str, Any]]]](
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
)
self._token_id_to_user_id: dict[str, str] = {}
async def async_get_groups(self) -> list[models.Group]:
"""Retrieve all users."""
@ -136,7 +137,10 @@ class AuthStore:
async def async_remove_user(self, user: models.User) -> None:
"""Remove a user."""
self._users.pop(user.id)
user = self._users.pop(user.id)
for refresh_token_id in user.refresh_tokens:
del self._token_id_to_user_id[refresh_token_id]
user.refresh_tokens.clear()
self._async_schedule_save()
async def async_update_user(
@ -219,7 +223,9 @@ class AuthStore:
kwargs["client_icon"] = client_icon
refresh_token = models.RefreshToken(**kwargs)
user.refresh_tokens[refresh_token.id] = refresh_token
token_id = refresh_token.id
user.refresh_tokens[token_id] = refresh_token
self._token_id_to_user_id[token_id] = user.id
self._async_schedule_save()
return refresh_token
@ -227,19 +233,17 @@ class AuthStore:
@callback
def async_remove_refresh_token(self, refresh_token: models.RefreshToken) -> None:
"""Remove a refresh token."""
for user in self._users.values():
if user.refresh_tokens.pop(refresh_token.id, None):
self._async_schedule_save()
break
refresh_token_id = refresh_token.id
if user_id := self._token_id_to_user_id.get(refresh_token_id):
del self._users[user_id].refresh_tokens[refresh_token_id]
del self._token_id_to_user_id[refresh_token_id]
self._async_schedule_save()
@callback
def async_get_refresh_token(self, token_id: str) -> models.RefreshToken | None:
"""Get refresh token by id."""
for user in self._users.values():
refresh_token = user.refresh_tokens.get(token_id)
if refresh_token is not None:
return refresh_token
if user_id := self._token_id_to_user_id.get(token_id):
return self._users[user_id].refresh_tokens.get(token_id)
return None
@callback
@ -479,9 +483,18 @@ class AuthStore:
self._groups = groups
self._users = users
self._build_token_id_to_user_id()
self._async_schedule_save(INITIAL_LOAD_SAVE_DELAY)
@callback
def _build_token_id_to_user_id(self) -> None:
"""Build a map of token id to user id."""
self._token_id_to_user_id = {
token_id: user_id
for user_id, user in self._users.items()
for token_id in user.refresh_tokens
}
@callback
def _async_schedule_save(self, delay: float = DEFAULT_SAVE_DELAY) -> None:
"""Save users."""
@ -575,6 +588,7 @@ class AuthStore:
read_only_group = _system_read_only_group()
groups[read_only_group.id] = read_only_group
self._groups = groups
self._build_token_id_to_user_id()
def _system_admin_group() -> models.Group:

View File

@ -305,3 +305,24 @@ async def test_loading_does_not_write_right_away(
# Once for the task
await hass.async_block_till_done()
assert hass_storage[auth_store.STORAGE_KEY] != {}
async def test_add_remove_user_affects_tokens(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
"""Test adding and removing a user removes the tokens."""
store = auth_store.AuthStore(hass)
await store.async_load()
user = await store.async_create_user("Test User")
assert user.name == "Test User"
refresh_token = await store.async_create_refresh_token(
user, "client_id", "access_token_expiration"
)
assert user.refresh_tokens == {refresh_token.id: refresh_token}
assert await store.async_get_user(user.id) == user
assert store.async_get_refresh_token(refresh_token.id) == refresh_token
assert store.async_get_refresh_token_by_token(refresh_token.token) == refresh_token
await store.async_remove_user(user)
assert store.async_get_refresh_token(refresh_token.id) is None
assert store.async_get_refresh_token_by_token(refresh_token.token) is None
assert user.refresh_tokens == {}