Refresh google calendar tokens with invalid expiration times (#69679)

* Refresh google calendar tokens with invalid expiration times

* Update tests/components/google/conftest.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Remove unnecessary async methods in functions being touched already

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
pull/69722/head
Allen Porter 2022-04-08 20:27:58 -07:00 committed by GitHub
parent b5b514b62f
commit 06d2aeec6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 6 deletions

View File

@ -185,8 +185,14 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass, entry
)
)
assert isinstance(implementation, DeviceAuth)
session = config_entry_oauth2_flow.OAuth2Session(hass, entry, implementation)
# Force a token refresh to fix a bug where tokens were persisted with
# expires_in (relative time delta) and expires_at (absolute time) swapped.
if session.token["expires_at"] >= datetime(2070, 1, 1).timestamp():
session.token["expires_in"] = 0
session.token["expires_at"] = datetime.now().timestamp()
await session.async_ensure_token_valid()
required_scope = hass.data[DOMAIN][DATA_CONFIG][CONF_CALENDAR_ACCESS].scope
if required_scope not in session.token.get("scope", []):
raise ConfigEntryAuthFailed(

View File

@ -132,9 +132,16 @@ async def token_scopes() -> list[str]:
@pytest.fixture
async def creds(token_scopes: list[str]) -> OAuth2Credentials:
def token_expiry() -> datetime.datetime:
"""Expiration time for credentials used in the test."""
return utcnow() + datetime.timedelta(days=7)
@pytest.fixture
def creds(
token_scopes: list[str], token_expiry: datetime.datetime
) -> OAuth2Credentials:
"""Fixture that defines creds used in the test."""
token_expiry = utcnow() + datetime.timedelta(days=7)
return OAuth2Credentials(
access_token="ACCESS_TOKEN",
client_id="client-id",
@ -156,9 +163,16 @@ async def storage() -> YieldFixture[FakeStorage]:
@pytest.fixture
async def config_entry(token_scopes: list[str]) -> MockConfigEntry:
def config_entry_token_expiry(token_expiry: datetime.datetime) -> float:
"""Fixture for token expiration value stored in the config entry."""
return token_expiry.timestamp()
@pytest.fixture
async def config_entry(
token_scopes: list[str], config_entry_token_expiry: float
) -> MockConfigEntry:
"""Fixture to create a config entry for the integration."""
token_expiry = utcnow() + datetime.timedelta(days=7)
return MockConfigEntry(
domain=DOMAIN,
data={
@ -168,7 +182,7 @@ async def config_entry(token_scopes: list[str]) -> MockConfigEntry:
"refresh_token": "REFRESH_TOKEN",
"scope": " ".join(token_scopes),
"token_type": "Bearer",
"expires_at": token_expiry.timestamp(),
"expires_at": config_entry_token_expiry,
},
},
)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from collections.abc import Awaitable, Callable
import datetime
import time
from typing import Any
from unittest.mock import Mock, call, patch
@ -29,6 +30,7 @@ from .conftest import (
)
from tests.common import MockConfigEntry
from tests.test_util.aiohttp import AiohttpClientMocker
# Typing helpers
HassApi = Callable[[], Awaitable[dict[str, Any]]]
@ -471,3 +473,37 @@ async def test_scan_calendars(
assert state
assert state.name == "Calendar 2"
assert state.state == STATE_OFF
@pytest.mark.parametrize(
"config_entry_token_expiry", [datetime.datetime.max.timestamp() + 1]
)
async def test_invalid_token_expiry_in_config_entry(
hass: HomeAssistant,
component_setup: ComponentSetup,
setup_config_entry: MockConfigEntry,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Exercise case in issue #69623 with invalid token expiration persisted."""
# The token is refreshed and new expiration values are returned
expires_in = 86400
expires_at = time.time() + expires_in
aioclient_mock.post(
"https://oauth2.googleapis.com/token",
json={
"refresh_token": "some-refresh-token",
"access_token": "some-updated-token",
"expires_at": expires_at,
"expires_in": expires_in,
},
)
assert await component_setup()
# Verify token expiration values are updated
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
assert entries[0].state is ConfigEntryState.LOADED
assert entries[0].data["token"]["access_token"] == "some-updated-token"
assert entries[0].data["token"]["expires_in"] == expires_in