Refresh the nest authentication token on integration start before invoking the pub/sub subsciber (#138003)

* Refresh the nest authentication token on integration start before invoking the pub/sub subscriber

* Apply suggestions from code review

---------

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
pull/138131/head
Allen Porter 2025-02-09 09:31:18 -08:00 committed by GitHub
parent 0bd161a45a
commit b1f3068b41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 84 deletions

View File

@ -198,7 +198,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: NestConfigEntry) -> bool
entry, unique_id=entry.data[CONF_PROJECT_ID]
)
subscriber = await api.new_subscriber(hass, entry)
auth = await api.new_auth(hass, entry)
try:
await auth.async_get_access_token()
except AuthException as err:
raise ConfigEntryAuthFailed(f"Authentication error: {err!s}") from err
except ConfigurationException as err:
_LOGGER.error("Configuration error: %s", err)
return False
subscriber = await api.new_subscriber(hass, entry, auth)
if not subscriber:
return False
# Keep media for last N events in memory

View File

@ -101,9 +101,7 @@ class AccessTokenAuthImpl(AbstractAuth):
)
async def new_subscriber(
hass: HomeAssistant, entry: NestConfigEntry
) -> GoogleNestSubscriber | None:
async def new_auth(hass: HomeAssistant, entry: NestConfigEntry) -> AbstractAuth:
"""Create a GoogleNestSubscriber."""
implementation = (
await config_entry_oauth2_flow.async_get_config_entry_implementation(
@ -114,14 +112,22 @@ async def new_subscriber(
implementation, config_entry_oauth2_flow.LocalOAuth2Implementation
):
raise TypeError(f"Unexpected auth implementation {implementation}")
if (subscription_name := entry.data.get(CONF_SUBSCRIPTION_NAME)) is None:
subscription_name = entry.data[CONF_SUBSCRIBER_ID]
auth = AsyncConfigEntryAuth(
return AsyncConfigEntryAuth(
aiohttp_client.async_get_clientsession(hass),
config_entry_oauth2_flow.OAuth2Session(hass, entry, implementation),
implementation.client_id,
implementation.client_secret,
)
async def new_subscriber(
hass: HomeAssistant,
entry: NestConfigEntry,
auth: AbstractAuth,
) -> GoogleNestSubscriber:
"""Create a GoogleNestSubscriber."""
if (subscription_name := entry.data.get(CONF_SUBSCRIPTION_NAME)) is None:
subscription_name = entry.data[CONF_SUBSCRIBER_ID]
return GoogleNestSubscriber(auth, entry.data[CONF_PROJECT_ID], subscription_name)

View File

@ -89,80 +89,3 @@ async def test_auth(
assert creds.client_id == CLIENT_ID
assert creds.client_secret == CLIENT_SECRET
assert creds.scopes == SDM_SCOPES
# This tests needs to be adjusted to remove lingering tasks
@pytest.mark.parametrize("expected_lingering_tasks", [True])
@pytest.mark.parametrize(
"token_expiration_time",
[time.time() - 7 * 86400],
ids=["expires-in-past"],
)
async def test_auth_expired_token(
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
setup_platform: PlatformSetup,
token_expiration_time: float,
) -> None:
"""Verify behavior of an expired token."""
# Prepare a token refresh response
aioclient_mock.post(
OAUTH2_TOKEN,
json={
"access_token": FAKE_UPDATED_TOKEN,
"expires_at": time.time() + 86400,
"expires_in": 86400,
},
)
# Prepare to capture credentials in API request. Empty payloads just mean
# no devices or structures are loaded.
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/structures", json={})
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/devices", json={})
# Prepare to capture credentials for Subscriber
captured_creds = None
def async_new_subscriber(
credentials: Credentials,
) -> Mock:
"""Capture credentials for tests."""
nonlocal captured_creds
captured_creds = credentials
return AsyncMock()
with patch(
"google_nest_sdm.subscriber_client.pubsub_v1.SubscriberAsyncClient",
side_effect=async_new_subscriber,
) as new_subscriber_mock:
await setup_platform()
calls = aioclient_mock.mock_calls
assert len(calls) == 3
# Verify refresh token call to get an updated token
(method, url, data, headers) = calls[0]
assert data == {
"client_id": CLIENT_ID,
"client_secret": CLIENT_SECRET,
"grant_type": "refresh_token",
"refresh_token": FAKE_REFRESH_TOKEN,
}
# Verify API requests are made with the new token
(method, url, data, headers) = calls[1]
assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"}
(method, url, data, headers) = calls[2]
assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"}
# The subscriber is created with a token that is expired. Verify that the
# credential is expired so the subscriber knows it needs to refresh it.
assert len(new_subscriber_mock.mock_calls) == 1
assert captured_creds
creds = captured_creds
assert creds.token == FAKE_TOKEN
assert creds.refresh_token == FAKE_REFRESH_TOKEN
assert int(dt_util.as_timestamp(creds.expiry)) == int(token_expiration_time)
assert not creds.valid
assert creds.expired
assert creds.token_uri == OAUTH2_TOKEN
assert creds.client_id == CLIENT_ID
assert creds.client_secret == CLIENT_SECRET
assert creds.scopes == SDM_SCOPES