Add lock to token validity check (#117912)

pull/117936/head^2
Joost Lekkerkerker 2024-05-22 20:10:23 +02:00 committed by GitHub
parent 55c8ef1c7b
commit 0c5296b38f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 9 additions and 6 deletions

View File

@ -10,6 +10,7 @@ from __future__ import annotations
from abc import ABC, ABCMeta, abstractmethod from abc import ABC, ABCMeta, abstractmethod
import asyncio import asyncio
from asyncio import Lock
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from http import HTTPStatus from http import HTTPStatus
from json import JSONDecodeError from json import JSONDecodeError
@ -506,6 +507,7 @@ class OAuth2Session:
self.hass = hass self.hass = hass
self.config_entry = config_entry self.config_entry = config_entry
self.implementation = implementation self.implementation = implementation
self._token_lock = Lock()
@property @property
def token(self) -> dict: def token(self) -> dict:
@ -522,14 +524,15 @@ class OAuth2Session:
async def async_ensure_token_valid(self) -> None: async def async_ensure_token_valid(self) -> None:
"""Ensure that the current token is valid.""" """Ensure that the current token is valid."""
if self.valid_token: async with self._token_lock:
return if self.valid_token:
return
new_token = await self.implementation.async_refresh_token(self.token) new_token = await self.implementation.async_refresh_token(self.token)
self.hass.config_entries.async_update_entry( self.hass.config_entries.async_update_entry(
self.config_entry, data={**self.config_entry.data, "token": new_token} self.config_entry, data={**self.config_entry.data, "token": new_token}
) )
async def async_request( async def async_request(
self, method: str, url: str, **kwargs: Any self, method: str, url: str, **kwargs: Any