Refresh the nest authentication token on integration start before invoking the pub/sub subsciber (#138003)
* Refresh the nest authentication token on integration start before invoking the pub/sub subscriber * Apply suggestions from code review --------- Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>pull/138131/head
parent
0bd161a45a
commit
b1f3068b41
|
@ -198,7 +198,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: NestConfigEntry) -> bool
|
||||||
entry, unique_id=entry.data[CONF_PROJECT_ID]
|
entry, unique_id=entry.data[CONF_PROJECT_ID]
|
||||||
)
|
)
|
||||||
|
|
||||||
subscriber = await api.new_subscriber(hass, entry)
|
auth = await api.new_auth(hass, entry)
|
||||||
|
try:
|
||||||
|
await auth.async_get_access_token()
|
||||||
|
except AuthException as err:
|
||||||
|
raise ConfigEntryAuthFailed(f"Authentication error: {err!s}") from err
|
||||||
|
except ConfigurationException as err:
|
||||||
|
_LOGGER.error("Configuration error: %s", err)
|
||||||
|
return False
|
||||||
|
|
||||||
|
subscriber = await api.new_subscriber(hass, entry, auth)
|
||||||
if not subscriber:
|
if not subscriber:
|
||||||
return False
|
return False
|
||||||
# Keep media for last N events in memory
|
# Keep media for last N events in memory
|
||||||
|
|
|
@ -101,9 +101,7 @@ class AccessTokenAuthImpl(AbstractAuth):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def new_subscriber(
|
async def new_auth(hass: HomeAssistant, entry: NestConfigEntry) -> AbstractAuth:
|
||||||
hass: HomeAssistant, entry: NestConfigEntry
|
|
||||||
) -> GoogleNestSubscriber | None:
|
|
||||||
"""Create a GoogleNestSubscriber."""
|
"""Create a GoogleNestSubscriber."""
|
||||||
implementation = (
|
implementation = (
|
||||||
await config_entry_oauth2_flow.async_get_config_entry_implementation(
|
await config_entry_oauth2_flow.async_get_config_entry_implementation(
|
||||||
|
@ -114,14 +112,22 @@ async def new_subscriber(
|
||||||
implementation, config_entry_oauth2_flow.LocalOAuth2Implementation
|
implementation, config_entry_oauth2_flow.LocalOAuth2Implementation
|
||||||
):
|
):
|
||||||
raise TypeError(f"Unexpected auth implementation {implementation}")
|
raise TypeError(f"Unexpected auth implementation {implementation}")
|
||||||
if (subscription_name := entry.data.get(CONF_SUBSCRIPTION_NAME)) is None:
|
return AsyncConfigEntryAuth(
|
||||||
subscription_name = entry.data[CONF_SUBSCRIBER_ID]
|
|
||||||
auth = AsyncConfigEntryAuth(
|
|
||||||
aiohttp_client.async_get_clientsession(hass),
|
aiohttp_client.async_get_clientsession(hass),
|
||||||
config_entry_oauth2_flow.OAuth2Session(hass, entry, implementation),
|
config_entry_oauth2_flow.OAuth2Session(hass, entry, implementation),
|
||||||
implementation.client_id,
|
implementation.client_id,
|
||||||
implementation.client_secret,
|
implementation.client_secret,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def new_subscriber(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
entry: NestConfigEntry,
|
||||||
|
auth: AbstractAuth,
|
||||||
|
) -> GoogleNestSubscriber:
|
||||||
|
"""Create a GoogleNestSubscriber."""
|
||||||
|
if (subscription_name := entry.data.get(CONF_SUBSCRIPTION_NAME)) is None:
|
||||||
|
subscription_name = entry.data[CONF_SUBSCRIBER_ID]
|
||||||
return GoogleNestSubscriber(auth, entry.data[CONF_PROJECT_ID], subscription_name)
|
return GoogleNestSubscriber(auth, entry.data[CONF_PROJECT_ID], subscription_name)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -89,80 +89,3 @@ async def test_auth(
|
||||||
assert creds.client_id == CLIENT_ID
|
assert creds.client_id == CLIENT_ID
|
||||||
assert creds.client_secret == CLIENT_SECRET
|
assert creds.client_secret == CLIENT_SECRET
|
||||||
assert creds.scopes == SDM_SCOPES
|
assert creds.scopes == SDM_SCOPES
|
||||||
|
|
||||||
|
|
||||||
# This tests needs to be adjusted to remove lingering tasks
|
|
||||||
@pytest.mark.parametrize("expected_lingering_tasks", [True])
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"token_expiration_time",
|
|
||||||
[time.time() - 7 * 86400],
|
|
||||||
ids=["expires-in-past"],
|
|
||||||
)
|
|
||||||
async def test_auth_expired_token(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
aioclient_mock: AiohttpClientMocker,
|
|
||||||
setup_platform: PlatformSetup,
|
|
||||||
token_expiration_time: float,
|
|
||||||
) -> None:
|
|
||||||
"""Verify behavior of an expired token."""
|
|
||||||
# Prepare a token refresh response
|
|
||||||
aioclient_mock.post(
|
|
||||||
OAUTH2_TOKEN,
|
|
||||||
json={
|
|
||||||
"access_token": FAKE_UPDATED_TOKEN,
|
|
||||||
"expires_at": time.time() + 86400,
|
|
||||||
"expires_in": 86400,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Prepare to capture credentials in API request. Empty payloads just mean
|
|
||||||
# no devices or structures are loaded.
|
|
||||||
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/structures", json={})
|
|
||||||
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/devices", json={})
|
|
||||||
|
|
||||||
# Prepare to capture credentials for Subscriber
|
|
||||||
captured_creds = None
|
|
||||||
|
|
||||||
def async_new_subscriber(
|
|
||||||
credentials: Credentials,
|
|
||||||
) -> Mock:
|
|
||||||
"""Capture credentials for tests."""
|
|
||||||
nonlocal captured_creds
|
|
||||||
captured_creds = credentials
|
|
||||||
return AsyncMock()
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"google_nest_sdm.subscriber_client.pubsub_v1.SubscriberAsyncClient",
|
|
||||||
side_effect=async_new_subscriber,
|
|
||||||
) as new_subscriber_mock:
|
|
||||||
await setup_platform()
|
|
||||||
|
|
||||||
calls = aioclient_mock.mock_calls
|
|
||||||
assert len(calls) == 3
|
|
||||||
# Verify refresh token call to get an updated token
|
|
||||||
(method, url, data, headers) = calls[0]
|
|
||||||
assert data == {
|
|
||||||
"client_id": CLIENT_ID,
|
|
||||||
"client_secret": CLIENT_SECRET,
|
|
||||||
"grant_type": "refresh_token",
|
|
||||||
"refresh_token": FAKE_REFRESH_TOKEN,
|
|
||||||
}
|
|
||||||
# Verify API requests are made with the new token
|
|
||||||
(method, url, data, headers) = calls[1]
|
|
||||||
assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"}
|
|
||||||
(method, url, data, headers) = calls[2]
|
|
||||||
assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"}
|
|
||||||
|
|
||||||
# The subscriber is created with a token that is expired. Verify that the
|
|
||||||
# credential is expired so the subscriber knows it needs to refresh it.
|
|
||||||
assert len(new_subscriber_mock.mock_calls) == 1
|
|
||||||
assert captured_creds
|
|
||||||
creds = captured_creds
|
|
||||||
assert creds.token == FAKE_TOKEN
|
|
||||||
assert creds.refresh_token == FAKE_REFRESH_TOKEN
|
|
||||||
assert int(dt_util.as_timestamp(creds.expiry)) == int(token_expiration_time)
|
|
||||||
assert not creds.valid
|
|
||||||
assert creds.expired
|
|
||||||
assert creds.token_uri == OAUTH2_TOKEN
|
|
||||||
assert creds.client_id == CLIENT_ID
|
|
||||||
assert creds.client_secret == CLIENT_SECRET
|
|
||||||
assert creds.scopes == SDM_SCOPES
|
|
||||||
|
|
Loading…
Reference in New Issue