Refresh the nest authentication token on integration start before invoking the pub/sub subsciber (#138003)
* Refresh the nest authentication token on integration start before invoking the pub/sub subscriber * Apply suggestions from code review --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>pull/138131/head
parent
0bd161a45a
commit
b1f3068b41
|
@ -198,7 +198,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: NestConfigEntry) -> bool
|
|||
entry, unique_id=entry.data[CONF_PROJECT_ID]
|
||||
)
|
||||
|
||||
subscriber = await api.new_subscriber(hass, entry)
|
||||
auth = await api.new_auth(hass, entry)
|
||||
try:
|
||||
await auth.async_get_access_token()
|
||||
except AuthException as err:
|
||||
raise ConfigEntryAuthFailed(f"Authentication error: {err!s}") from err
|
||||
except ConfigurationException as err:
|
||||
_LOGGER.error("Configuration error: %s", err)
|
||||
return False
|
||||
|
||||
subscriber = await api.new_subscriber(hass, entry, auth)
|
||||
if not subscriber:
|
||||
return False
|
||||
# Keep media for last N events in memory
|
||||
|
|
|
@ -101,9 +101,7 @@ class AccessTokenAuthImpl(AbstractAuth):
|
|||
)
|
||||
|
||||
|
||||
async def new_subscriber(
|
||||
hass: HomeAssistant, entry: NestConfigEntry
|
||||
) -> GoogleNestSubscriber | None:
|
||||
async def new_auth(hass: HomeAssistant, entry: NestConfigEntry) -> AbstractAuth:
|
||||
"""Create a GoogleNestSubscriber."""
|
||||
implementation = (
|
||||
await config_entry_oauth2_flow.async_get_config_entry_implementation(
|
||||
|
@ -114,14 +112,22 @@ async def new_subscriber(
|
|||
implementation, config_entry_oauth2_flow.LocalOAuth2Implementation
|
||||
):
|
||||
raise TypeError(f"Unexpected auth implementation {implementation}")
|
||||
if (subscription_name := entry.data.get(CONF_SUBSCRIPTION_NAME)) is None:
|
||||
subscription_name = entry.data[CONF_SUBSCRIBER_ID]
|
||||
auth = AsyncConfigEntryAuth(
|
||||
return AsyncConfigEntryAuth(
|
||||
aiohttp_client.async_get_clientsession(hass),
|
||||
config_entry_oauth2_flow.OAuth2Session(hass, entry, implementation),
|
||||
implementation.client_id,
|
||||
implementation.client_secret,
|
||||
)
|
||||
|
||||
|
||||
async def new_subscriber(
|
||||
hass: HomeAssistant,
|
||||
entry: NestConfigEntry,
|
||||
auth: AbstractAuth,
|
||||
) -> GoogleNestSubscriber:
|
||||
"""Create a GoogleNestSubscriber."""
|
||||
if (subscription_name := entry.data.get(CONF_SUBSCRIPTION_NAME)) is None:
|
||||
subscription_name = entry.data[CONF_SUBSCRIBER_ID]
|
||||
return GoogleNestSubscriber(auth, entry.data[CONF_PROJECT_ID], subscription_name)
|
||||
|
||||
|
||||
|
|
|
@ -89,80 +89,3 @@ async def test_auth(
|
|||
assert creds.client_id == CLIENT_ID
|
||||
assert creds.client_secret == CLIENT_SECRET
|
||||
assert creds.scopes == SDM_SCOPES
|
||||
|
||||
|
||||
# This tests needs to be adjusted to remove lingering tasks
|
||||
@pytest.mark.parametrize("expected_lingering_tasks", [True])
|
||||
@pytest.mark.parametrize(
|
||||
"token_expiration_time",
|
||||
[time.time() - 7 * 86400],
|
||||
ids=["expires-in-past"],
|
||||
)
|
||||
async def test_auth_expired_token(
|
||||
hass: HomeAssistant,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
setup_platform: PlatformSetup,
|
||||
token_expiration_time: float,
|
||||
) -> None:
|
||||
"""Verify behavior of an expired token."""
|
||||
# Prepare a token refresh response
|
||||
aioclient_mock.post(
|
||||
OAUTH2_TOKEN,
|
||||
json={
|
||||
"access_token": FAKE_UPDATED_TOKEN,
|
||||
"expires_at": time.time() + 86400,
|
||||
"expires_in": 86400,
|
||||
},
|
||||
)
|
||||
# Prepare to capture credentials in API request. Empty payloads just mean
|
||||
# no devices or structures are loaded.
|
||||
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/structures", json={})
|
||||
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/devices", json={})
|
||||
|
||||
# Prepare to capture credentials for Subscriber
|
||||
captured_creds = None
|
||||
|
||||
def async_new_subscriber(
|
||||
credentials: Credentials,
|
||||
) -> Mock:
|
||||
"""Capture credentials for tests."""
|
||||
nonlocal captured_creds
|
||||
captured_creds = credentials
|
||||
return AsyncMock()
|
||||
|
||||
with patch(
|
||||
"google_nest_sdm.subscriber_client.pubsub_v1.SubscriberAsyncClient",
|
||||
side_effect=async_new_subscriber,
|
||||
) as new_subscriber_mock:
|
||||
await setup_platform()
|
||||
|
||||
calls = aioclient_mock.mock_calls
|
||||
assert len(calls) == 3
|
||||
# Verify refresh token call to get an updated token
|
||||
(method, url, data, headers) = calls[0]
|
||||
assert data == {
|
||||
"client_id": CLIENT_ID,
|
||||
"client_secret": CLIENT_SECRET,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": FAKE_REFRESH_TOKEN,
|
||||
}
|
||||
# Verify API requests are made with the new token
|
||||
(method, url, data, headers) = calls[1]
|
||||
assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"}
|
||||
(method, url, data, headers) = calls[2]
|
||||
assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"}
|
||||
|
||||
# The subscriber is created with a token that is expired. Verify that the
|
||||
# credential is expired so the subscriber knows it needs to refresh it.
|
||||
assert len(new_subscriber_mock.mock_calls) == 1
|
||||
assert captured_creds
|
||||
creds = captured_creds
|
||||
assert creds.token == FAKE_TOKEN
|
||||
assert creds.refresh_token == FAKE_REFRESH_TOKEN
|
||||
assert int(dt_util.as_timestamp(creds.expiry)) == int(token_expiration_time)
|
||||
assert not creds.valid
|
||||
assert creds.expired
|
||||
assert creds.token_uri == OAUTH2_TOKEN
|
||||
assert creds.client_id == CLIENT_ID
|
||||
assert creds.client_secret == CLIENT_SECRET
|
||||
assert creds.scopes == SDM_SCOPES
|
||||
|
|
Loading…
Reference in New Issue