Resolve nest pub/sub subscriber token refresh issues (#44686)

pull/44761/head
Allen Porter 2021-01-01 16:51:01 -08:00 committed by GitHub
parent a2ca08905f
commit 321c0a87ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 212 additions and 34 deletions

View File

@ -580,7 +580,6 @@ omit =
homeassistant/components/neato/vacuum.py
homeassistant/components/nederlandse_spoorwegen/sensor.py
homeassistant/components/nello/lock.py
homeassistant/components/nest/api.py
homeassistant/components/nest/legacy/*
homeassistant/components/netatmo/__init__.py
homeassistant/components/netatmo/api.py

View File

@ -30,14 +30,7 @@ from homeassistant.helpers import (
)
from . import api, config_flow
from .const import (
API_URL,
DATA_SDM,
DATA_SUBSCRIBER,
DOMAIN,
OAUTH2_AUTHORIZE,
OAUTH2_TOKEN,
)
from .const import DATA_SDM, DATA_SUBSCRIBER, DOMAIN, OAUTH2_AUTHORIZE, OAUTH2_TOKEN
from .events import EVENT_NAME_MAP, NEST_EVENT
from .legacy import async_setup_legacy, async_setup_legacy_entry
@ -161,7 +154,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
auth = api.AsyncConfigEntryAuth(
aiohttp_client.async_get_clientsession(hass),
session,
API_URL,
config[CONF_CLIENT_ID],
config[CONF_CLIENT_SECRET],
)
subscriber = GoogleNestSubscriber(
auth, config[CONF_PROJECT_ID], config[CONF_SUBSCRIBER_ID]

View File

@ -1,11 +1,15 @@
"""API for Google Nest Device Access bound to Home Assistant OAuth."""
import datetime
from aiohttp import ClientSession
from google.oauth2.credentials import Credentials
from google_nest_sdm.auth import AbstractAuth
from homeassistant.helpers import config_entry_oauth2_flow
from .const import API_URL, OAUTH2_TOKEN, SDM_SCOPES
# See https://developers.google.com/nest/device-access/registration
@ -16,20 +20,37 @@ class AsyncConfigEntryAuth(AbstractAuth):
self,
websession: ClientSession,
oauth_session: config_entry_oauth2_flow.OAuth2Session,
api_url: str,
client_id: str,
client_secret: str,
):
"""Initialize Google Nest Device Access auth."""
super().__init__(websession, api_url)
super().__init__(websession, API_URL)
self._oauth_session = oauth_session
self._client_id = client_id
self._client_secret = client_secret
async def async_get_access_token(self):
"""Return a valid access token."""
"""Return a valid access token for SDM API."""
if not self._oauth_session.valid_token:
await self._oauth_session.async_ensure_token_valid()
return self._oauth_session.token["access_token"]
async def async_get_creds(self):
"""Return a minimal OAuth credential."""
token = await self.async_get_access_token()
return Credentials(token=token)
"""Return an OAuth credential for Pub/Sub Subscriber."""
# We don't have a way for Home Assistant to refresh creds on behalf
# of the google pub/sub subscriber. Instead, build a full
# Credentials object with enough information for the subscriber to
# handle this on its own. We purposely don't refresh the token here
# even when it is expired to fully hand off this responsibility and
# know it is working at startup (then if not, fail loudly).
token = self._oauth_session.token
creds = Credentials(
token=token["access_token"],
refresh_token=token["refresh_token"],
token_uri=OAUTH2_TOKEN,
client_id=self._client_id,
client_secret=self._client_secret,
scopes=SDM_SCOPES,
)
creds.expiry = datetime.datetime.fromtimestamp(token["expires_at"])
return creds

View File

@ -9,30 +9,45 @@ from google_nest_sdm.event import EventMessage
from google_nest_sdm.google_nest_subscriber import GoogleNestSubscriber
from homeassistant.components.nest import DOMAIN
from homeassistant.components.nest.const import SDM_SCOPES
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
PROJECT_ID = "some-project-id"
CLIENT_ID = "some-client-id"
CLIENT_SECRET = "some-client-secret"
CONFIG = {
"nest": {
"client_id": "some-client-id",
"client_secret": "some-client-secret",
"client_id": CLIENT_ID,
"client_secret": CLIENT_SECRET,
# Required fields for using SDM API
"project_id": "some-project-id",
"project_id": PROJECT_ID,
"subscriber_id": "projects/example/subscriptions/subscriber-id-9876",
},
}
CONFIG_ENTRY_DATA = {
"sdm": {}, # Indicates new SDM API, not legacy API
"auth_implementation": "local",
"token": {
"expires_at": time.time() + 86400,
"access_token": {
"token": "some-token",
FAKE_TOKEN = "some-token"
FAKE_REFRESH_TOKEN = "some-refresh-token"
def create_config_entry(hass, token_expiration_time=None):
"""Create a ConfigEntry and add it to Home Assistant."""
if token_expiration_time is None:
token_expiration_time = time.time() + 86400
config_entry_data = {
"sdm": {}, # Indicates new SDM API, not legacy API
"auth_implementation": "nest",
"token": {
"access_token": FAKE_TOKEN,
"refresh_token": FAKE_REFRESH_TOKEN,
"scope": " ".join(SDM_SCOPES),
"token_type": "Bearer",
"expires_at": token_expiration_time,
},
},
}
}
MockConfigEntry(domain=DOMAIN, data=config_entry_data).add_to_hass(hass)
class FakeDeviceManager(DeviceManager):
@ -86,7 +101,7 @@ class FakeSubscriber(GoogleNestSubscriber):
async def async_setup_sdm_platform(hass, platform, devices={}, structures={}):
"""Set up the platform and prerequisites."""
MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA).add_to_hass(hass)
create_config_entry(hass)
device_manager = FakeDeviceManager(devices=devices, structures=structures)
subscriber = FakeSubscriber(device_manager)
with patch(

View File

@ -0,0 +1,151 @@
"""Tests for the Nest integration API glue library.
There are two interesting cases to exercise that have different strategies
for token refresh and for testing:
- API based requests, tested using aioclient_mock
- Pub/sub subcriber initialization, intercepted with patch()
The tests below exercise both cases during integration setup.
"""
import time
from unittest.mock import patch
from homeassistant.components.nest import DOMAIN
from homeassistant.components.nest.const import API_URL, OAUTH2_TOKEN, SDM_SCOPES
from homeassistant.setup import async_setup_component
from homeassistant.util import dt
from .common import (
CLIENT_ID,
CLIENT_SECRET,
CONFIG,
FAKE_REFRESH_TOKEN,
FAKE_TOKEN,
PROJECT_ID,
create_config_entry,
)
FAKE_UPDATED_TOKEN = "fake-updated-token"
async def async_setup_sdm(hass):
"""Set up the integration."""
assert await async_setup_component(hass, DOMAIN, CONFIG)
await hass.async_block_till_done()
async def test_auth(hass, aioclient_mock):
"""Exercise authentication library creates valid credentials."""
expiration_time = time.time() + 86400
create_config_entry(hass, expiration_time)
# Prepare to capture credentials in API request. Empty payloads just mean
# no devices or structures are loaded.
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/structures", json={})
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/devices", json={})
# Prepare to capture credentials for Subscriber
captured_creds = None
async def async_new_subscriber(creds, subscription_name, loop, async_callback):
"""Capture credentials for tests."""
nonlocal captured_creds
captured_creds = creds
return None # GoogleNestSubscriber
with patch(
"google_nest_sdm.google_nest_subscriber.DefaultSubscriberFactory.async_new_subscriber",
side_effect=async_new_subscriber,
) as new_subscriber_mock:
await async_setup_sdm(hass)
# Verify API requests are made with the correct credentials
calls = aioclient_mock.mock_calls
assert len(calls) == 2
(method, url, data, headers) = calls[0]
assert headers == {"Authorization": f"Bearer {FAKE_TOKEN}"}
(method, url, data, headers) = calls[1]
assert headers == {"Authorization": f"Bearer {FAKE_TOKEN}"}
# Verify the susbcriber was created with the correct credentials
assert len(new_subscriber_mock.mock_calls) == 1
assert captured_creds
creds = captured_creds
assert creds.token == FAKE_TOKEN
assert creds.refresh_token == FAKE_REFRESH_TOKEN
assert int(dt.as_timestamp(creds.expiry)) == int(expiration_time)
assert creds.valid
assert not creds.expired
assert creds.token_uri == OAUTH2_TOKEN
assert creds.client_id == CLIENT_ID
assert creds.client_secret == CLIENT_SECRET
assert creds.scopes == SDM_SCOPES
async def test_auth_expired_token(hass, aioclient_mock):
"""Verify behavior of an expired token."""
expiration_time = time.time() - 86400
create_config_entry(hass, expiration_time)
# Prepare a token refresh response
aioclient_mock.post(
OAUTH2_TOKEN,
json={
"access_token": FAKE_UPDATED_TOKEN,
"expires_at": time.time() + 86400,
"expires_in": 86400,
},
)
# Prepare to capture credentials in API request. Empty payloads just mean
# no devices or structures are loaded.
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/structures", json={})
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/devices", json={})
# Prepare to capture credentials for Subscriber
captured_creds = None
async def async_new_subscriber(creds, subscription_name, loop, async_callback):
"""Capture credentials for tests."""
nonlocal captured_creds
captured_creds = creds
return None # GoogleNestSubscriber
with patch(
"google_nest_sdm.google_nest_subscriber.DefaultSubscriberFactory.async_new_subscriber",
side_effect=async_new_subscriber,
) as new_subscriber_mock:
await async_setup_sdm(hass)
calls = aioclient_mock.mock_calls
assert len(calls) == 3
# Verify refresh token call to get an updated token
(method, url, data, headers) = calls[0]
assert data == {
"client_id": CLIENT_ID,
"client_secret": CLIENT_SECRET,
"grant_type": "refresh_token",
"refresh_token": FAKE_REFRESH_TOKEN,
}
# Verify API requests are made with the new token
(method, url, data, headers) = calls[1]
assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"}
(method, url, data, headers) = calls[2]
assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"}
# The subscriber is created with a token that is expired. Verify that the
# credential is expired so the subscriber knows it needs to refresh it.
assert len(new_subscriber_mock.mock_calls) == 1
assert captured_creds
creds = captured_creds
assert creds.token == FAKE_TOKEN
assert creds.refresh_token == FAKE_REFRESH_TOKEN
assert int(dt.as_timestamp(creds.expiry)) == int(expiration_time)
assert not creds.valid
assert creds.expired
assert creds.token_uri == OAUTH2_TOKEN
assert creds.client_id == CLIENT_ID
assert creds.client_secret == CLIENT_SECRET
assert creds.scopes == SDM_SCOPES

View File

@ -19,9 +19,7 @@ from homeassistant.config_entries import (
)
from homeassistant.setup import async_setup_component
from .common import CONFIG, CONFIG_ENTRY_DATA, async_setup_sdm_platform
from tests.common import MockConfigEntry
from .common import CONFIG, async_setup_sdm_platform, create_config_entry
PLATFORM = "sensor"
@ -39,7 +37,7 @@ async def test_setup_success(hass, caplog):
async def async_setup_sdm(hass, config=CONFIG):
"""Prepare test setup."""
MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA).add_to_hass(hass)
create_config_entry(hass)
with patch(
"homeassistant.helpers.config_entry_oauth2_flow.async_get_config_entry_implementation"
):