Resolve nest pub/sub subscriber token refresh issues (#44686)
parent
a2ca08905f
commit
321c0a87ae
|
@ -580,7 +580,6 @@ omit =
|
|||
homeassistant/components/neato/vacuum.py
|
||||
homeassistant/components/nederlandse_spoorwegen/sensor.py
|
||||
homeassistant/components/nello/lock.py
|
||||
homeassistant/components/nest/api.py
|
||||
homeassistant/components/nest/legacy/*
|
||||
homeassistant/components/netatmo/__init__.py
|
||||
homeassistant/components/netatmo/api.py
|
||||
|
|
|
@ -30,14 +30,7 @@ from homeassistant.helpers import (
|
|||
)
|
||||
|
||||
from . import api, config_flow
|
||||
from .const import (
|
||||
API_URL,
|
||||
DATA_SDM,
|
||||
DATA_SUBSCRIBER,
|
||||
DOMAIN,
|
||||
OAUTH2_AUTHORIZE,
|
||||
OAUTH2_TOKEN,
|
||||
)
|
||||
from .const import DATA_SDM, DATA_SUBSCRIBER, DOMAIN, OAUTH2_AUTHORIZE, OAUTH2_TOKEN
|
||||
from .events import EVENT_NAME_MAP, NEST_EVENT
|
||||
from .legacy import async_setup_legacy, async_setup_legacy_entry
|
||||
|
||||
|
@ -161,7 +154,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
|
|||
auth = api.AsyncConfigEntryAuth(
|
||||
aiohttp_client.async_get_clientsession(hass),
|
||||
session,
|
||||
API_URL,
|
||||
config[CONF_CLIENT_ID],
|
||||
config[CONF_CLIENT_SECRET],
|
||||
)
|
||||
subscriber = GoogleNestSubscriber(
|
||||
auth, config[CONF_PROJECT_ID], config[CONF_SUBSCRIBER_ID]
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
"""API for Google Nest Device Access bound to Home Assistant OAuth."""
|
||||
|
||||
import datetime
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_nest_sdm.auth import AbstractAuth
|
||||
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
|
||||
from .const import API_URL, OAUTH2_TOKEN, SDM_SCOPES
|
||||
|
||||
# See https://developers.google.com/nest/device-access/registration
|
||||
|
||||
|
||||
|
@ -16,20 +20,37 @@ class AsyncConfigEntryAuth(AbstractAuth):
|
|||
self,
|
||||
websession: ClientSession,
|
||||
oauth_session: config_entry_oauth2_flow.OAuth2Session,
|
||||
api_url: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
):
|
||||
"""Initialize Google Nest Device Access auth."""
|
||||
super().__init__(websession, api_url)
|
||||
super().__init__(websession, API_URL)
|
||||
self._oauth_session = oauth_session
|
||||
self._client_id = client_id
|
||||
self._client_secret = client_secret
|
||||
|
||||
async def async_get_access_token(self):
|
||||
"""Return a valid access token."""
|
||||
"""Return a valid access token for SDM API."""
|
||||
if not self._oauth_session.valid_token:
|
||||
await self._oauth_session.async_ensure_token_valid()
|
||||
|
||||
return self._oauth_session.token["access_token"]
|
||||
|
||||
async def async_get_creds(self):
|
||||
"""Return a minimal OAuth credential."""
|
||||
token = await self.async_get_access_token()
|
||||
return Credentials(token=token)
|
||||
"""Return an OAuth credential for Pub/Sub Subscriber."""
|
||||
# We don't have a way for Home Assistant to refresh creds on behalf
|
||||
# of the google pub/sub subscriber. Instead, build a full
|
||||
# Credentials object with enough information for the subscriber to
|
||||
# handle this on its own. We purposely don't refresh the token here
|
||||
# even when it is expired to fully hand off this responsibility and
|
||||
# know it is working at startup (then if not, fail loudly).
|
||||
token = self._oauth_session.token
|
||||
creds = Credentials(
|
||||
token=token["access_token"],
|
||||
refresh_token=token["refresh_token"],
|
||||
token_uri=OAUTH2_TOKEN,
|
||||
client_id=self._client_id,
|
||||
client_secret=self._client_secret,
|
||||
scopes=SDM_SCOPES,
|
||||
)
|
||||
creds.expiry = datetime.datetime.fromtimestamp(token["expires_at"])
|
||||
return creds
|
||||
|
|
|
@ -9,30 +9,45 @@ from google_nest_sdm.event import EventMessage
|
|||
from google_nest_sdm.google_nest_subscriber import GoogleNestSubscriber
|
||||
|
||||
from homeassistant.components.nest import DOMAIN
|
||||
from homeassistant.components.nest.const import SDM_SCOPES
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
PROJECT_ID = "some-project-id"
|
||||
CLIENT_ID = "some-client-id"
|
||||
CLIENT_SECRET = "some-client-secret"
|
||||
|
||||
CONFIG = {
|
||||
"nest": {
|
||||
"client_id": "some-client-id",
|
||||
"client_secret": "some-client-secret",
|
||||
"client_id": CLIENT_ID,
|
||||
"client_secret": CLIENT_SECRET,
|
||||
# Required fields for using SDM API
|
||||
"project_id": "some-project-id",
|
||||
"project_id": PROJECT_ID,
|
||||
"subscriber_id": "projects/example/subscriptions/subscriber-id-9876",
|
||||
},
|
||||
}
|
||||
|
||||
CONFIG_ENTRY_DATA = {
|
||||
"sdm": {}, # Indicates new SDM API, not legacy API
|
||||
"auth_implementation": "local",
|
||||
"token": {
|
||||
"expires_at": time.time() + 86400,
|
||||
"access_token": {
|
||||
"token": "some-token",
|
||||
FAKE_TOKEN = "some-token"
|
||||
FAKE_REFRESH_TOKEN = "some-refresh-token"
|
||||
|
||||
|
||||
def create_config_entry(hass, token_expiration_time=None):
|
||||
"""Create a ConfigEntry and add it to Home Assistant."""
|
||||
if token_expiration_time is None:
|
||||
token_expiration_time = time.time() + 86400
|
||||
config_entry_data = {
|
||||
"sdm": {}, # Indicates new SDM API, not legacy API
|
||||
"auth_implementation": "nest",
|
||||
"token": {
|
||||
"access_token": FAKE_TOKEN,
|
||||
"refresh_token": FAKE_REFRESH_TOKEN,
|
||||
"scope": " ".join(SDM_SCOPES),
|
||||
"token_type": "Bearer",
|
||||
"expires_at": token_expiration_time,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
MockConfigEntry(domain=DOMAIN, data=config_entry_data).add_to_hass(hass)
|
||||
|
||||
|
||||
class FakeDeviceManager(DeviceManager):
|
||||
|
@ -86,7 +101,7 @@ class FakeSubscriber(GoogleNestSubscriber):
|
|||
|
||||
async def async_setup_sdm_platform(hass, platform, devices={}, structures={}):
|
||||
"""Set up the platform and prerequisites."""
|
||||
MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA).add_to_hass(hass)
|
||||
create_config_entry(hass)
|
||||
device_manager = FakeDeviceManager(devices=devices, structures=structures)
|
||||
subscriber = FakeSubscriber(device_manager)
|
||||
with patch(
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
"""Tests for the Nest integration API glue library.
|
||||
|
||||
There are two interesting cases to exercise that have different strategies
|
||||
for token refresh and for testing:
|
||||
- API based requests, tested using aioclient_mock
|
||||
- Pub/sub subcriber initialization, intercepted with patch()
|
||||
|
||||
The tests below exercise both cases during integration setup.
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components.nest import DOMAIN
|
||||
from homeassistant.components.nest.const import API_URL, OAUTH2_TOKEN, SDM_SCOPES
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt
|
||||
|
||||
from .common import (
|
||||
CLIENT_ID,
|
||||
CLIENT_SECRET,
|
||||
CONFIG,
|
||||
FAKE_REFRESH_TOKEN,
|
||||
FAKE_TOKEN,
|
||||
PROJECT_ID,
|
||||
create_config_entry,
|
||||
)
|
||||
|
||||
FAKE_UPDATED_TOKEN = "fake-updated-token"
|
||||
|
||||
|
||||
async def async_setup_sdm(hass):
|
||||
"""Set up the integration."""
|
||||
assert await async_setup_component(hass, DOMAIN, CONFIG)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_auth(hass, aioclient_mock):
|
||||
"""Exercise authentication library creates valid credentials."""
|
||||
|
||||
expiration_time = time.time() + 86400
|
||||
create_config_entry(hass, expiration_time)
|
||||
|
||||
# Prepare to capture credentials in API request. Empty payloads just mean
|
||||
# no devices or structures are loaded.
|
||||
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/structures", json={})
|
||||
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/devices", json={})
|
||||
|
||||
# Prepare to capture credentials for Subscriber
|
||||
captured_creds = None
|
||||
|
||||
async def async_new_subscriber(creds, subscription_name, loop, async_callback):
|
||||
"""Capture credentials for tests."""
|
||||
nonlocal captured_creds
|
||||
captured_creds = creds
|
||||
return None # GoogleNestSubscriber
|
||||
|
||||
with patch(
|
||||
"google_nest_sdm.google_nest_subscriber.DefaultSubscriberFactory.async_new_subscriber",
|
||||
side_effect=async_new_subscriber,
|
||||
) as new_subscriber_mock:
|
||||
await async_setup_sdm(hass)
|
||||
|
||||
# Verify API requests are made with the correct credentials
|
||||
calls = aioclient_mock.mock_calls
|
||||
assert len(calls) == 2
|
||||
(method, url, data, headers) = calls[0]
|
||||
assert headers == {"Authorization": f"Bearer {FAKE_TOKEN}"}
|
||||
(method, url, data, headers) = calls[1]
|
||||
assert headers == {"Authorization": f"Bearer {FAKE_TOKEN}"}
|
||||
|
||||
# Verify the susbcriber was created with the correct credentials
|
||||
assert len(new_subscriber_mock.mock_calls) == 1
|
||||
assert captured_creds
|
||||
creds = captured_creds
|
||||
assert creds.token == FAKE_TOKEN
|
||||
assert creds.refresh_token == FAKE_REFRESH_TOKEN
|
||||
assert int(dt.as_timestamp(creds.expiry)) == int(expiration_time)
|
||||
assert creds.valid
|
||||
assert not creds.expired
|
||||
assert creds.token_uri == OAUTH2_TOKEN
|
||||
assert creds.client_id == CLIENT_ID
|
||||
assert creds.client_secret == CLIENT_SECRET
|
||||
assert creds.scopes == SDM_SCOPES
|
||||
|
||||
|
||||
async def test_auth_expired_token(hass, aioclient_mock):
|
||||
"""Verify behavior of an expired token."""
|
||||
|
||||
expiration_time = time.time() - 86400
|
||||
create_config_entry(hass, expiration_time)
|
||||
|
||||
# Prepare a token refresh response
|
||||
aioclient_mock.post(
|
||||
OAUTH2_TOKEN,
|
||||
json={
|
||||
"access_token": FAKE_UPDATED_TOKEN,
|
||||
"expires_at": time.time() + 86400,
|
||||
"expires_in": 86400,
|
||||
},
|
||||
)
|
||||
# Prepare to capture credentials in API request. Empty payloads just mean
|
||||
# no devices or structures are loaded.
|
||||
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/structures", json={})
|
||||
aioclient_mock.get(f"{API_URL}/enterprises/{PROJECT_ID}/devices", json={})
|
||||
|
||||
# Prepare to capture credentials for Subscriber
|
||||
captured_creds = None
|
||||
|
||||
async def async_new_subscriber(creds, subscription_name, loop, async_callback):
|
||||
"""Capture credentials for tests."""
|
||||
nonlocal captured_creds
|
||||
captured_creds = creds
|
||||
return None # GoogleNestSubscriber
|
||||
|
||||
with patch(
|
||||
"google_nest_sdm.google_nest_subscriber.DefaultSubscriberFactory.async_new_subscriber",
|
||||
side_effect=async_new_subscriber,
|
||||
) as new_subscriber_mock:
|
||||
await async_setup_sdm(hass)
|
||||
|
||||
calls = aioclient_mock.mock_calls
|
||||
assert len(calls) == 3
|
||||
# Verify refresh token call to get an updated token
|
||||
(method, url, data, headers) = calls[0]
|
||||
assert data == {
|
||||
"client_id": CLIENT_ID,
|
||||
"client_secret": CLIENT_SECRET,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": FAKE_REFRESH_TOKEN,
|
||||
}
|
||||
# Verify API requests are made with the new token
|
||||
(method, url, data, headers) = calls[1]
|
||||
assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"}
|
||||
(method, url, data, headers) = calls[2]
|
||||
assert headers == {"Authorization": f"Bearer {FAKE_UPDATED_TOKEN}"}
|
||||
|
||||
# The subscriber is created with a token that is expired. Verify that the
|
||||
# credential is expired so the subscriber knows it needs to refresh it.
|
||||
assert len(new_subscriber_mock.mock_calls) == 1
|
||||
assert captured_creds
|
||||
creds = captured_creds
|
||||
assert creds.token == FAKE_TOKEN
|
||||
assert creds.refresh_token == FAKE_REFRESH_TOKEN
|
||||
assert int(dt.as_timestamp(creds.expiry)) == int(expiration_time)
|
||||
assert not creds.valid
|
||||
assert creds.expired
|
||||
assert creds.token_uri == OAUTH2_TOKEN
|
||||
assert creds.client_id == CLIENT_ID
|
||||
assert creds.client_secret == CLIENT_SECRET
|
||||
assert creds.scopes == SDM_SCOPES
|
|
@ -19,9 +19,7 @@ from homeassistant.config_entries import (
|
|||
)
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .common import CONFIG, CONFIG_ENTRY_DATA, async_setup_sdm_platform
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from .common import CONFIG, async_setup_sdm_platform, create_config_entry
|
||||
|
||||
PLATFORM = "sensor"
|
||||
|
||||
|
@ -39,7 +37,7 @@ async def test_setup_success(hass, caplog):
|
|||
|
||||
async def async_setup_sdm(hass, config=CONFIG):
|
||||
"""Prepare test setup."""
|
||||
MockConfigEntry(domain=DOMAIN, data=CONFIG_ENTRY_DATA).add_to_hass(hass)
|
||||
create_config_entry(hass)
|
||||
with patch(
|
||||
"homeassistant.helpers.config_entry_oauth2_flow.async_get_config_entry_implementation"
|
||||
):
|
||||
|
|
Loading…
Reference in New Issue