Alexa typing part 1 (#97909)

* Typing part 1

* mypy

* Correct typing for logbook
pull/97991/head
Jan Bouwhuis 2023-08-07 20:36:30 +02:00 committed by GitHub
parent fb12c237ab
commit 40a221c923
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 83 additions and 51 deletions

View File

@ -1,15 +1,16 @@
"""Support for Alexa skill auth."""
import asyncio
from datetime import timedelta
from datetime import datetime, timedelta
from http import HTTPStatus
import json
import logging
from typing import Any
import aiohttp
import async_timeout
from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import aiohttp_client
from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util
@ -30,24 +31,24 @@ STORAGE_REFRESH_TOKEN = "refresh_token"
class Auth:
"""Handle authentication to send events to Alexa."""
def __init__(self, hass, client_id, client_secret):
def __init__(self, hass: HomeAssistant, client_id: str, client_secret: str) -> None:
"""Initialize the Auth class."""
self.hass = hass
self.client_id = client_id
self.client_secret = client_secret
self._prefs = None
self._store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
self._prefs: dict[str, Any] | None = None
self._store: Store = Store(hass, STORAGE_VERSION, STORAGE_KEY)
self._get_token_lock = asyncio.Lock()
async def async_do_auth(self, accept_grant_code):
async def async_do_auth(self, accept_grant_code: str) -> str | None:
"""Do authentication with an AcceptGrant code."""
# access token not retrieved yet for the first time, so this should
# be an access token request
lwa_params = {
lwa_params: dict[str, str] = {
"grant_type": "authorization_code",
"code": accept_grant_code,
CONF_CLIENT_ID: self.client_id,
@ -61,16 +62,18 @@ class Auth:
return await self._async_request_new_token(lwa_params)
@callback
def async_invalidate_access_token(self):
def async_invalidate_access_token(self) -> None:
"""Invalidate access token."""
assert self._prefs is not None
self._prefs[STORAGE_ACCESS_TOKEN] = None
async def async_get_access_token(self):
async def async_get_access_token(self) -> str | None:
"""Perform access token or token refresh request."""
async with self._get_token_lock:
if self._prefs is None:
await self.async_load_preferences()
assert self._prefs is not None
if self.is_token_valid():
_LOGGER.debug("Token still valid, using it")
return self._prefs[STORAGE_ACCESS_TOKEN]
@ -79,7 +82,7 @@ class Auth:
_LOGGER.debug("Token invalid and no refresh token available")
return None
lwa_params = {
lwa_params: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": self._prefs[STORAGE_REFRESH_TOKEN],
CONF_CLIENT_ID: self.client_id,
@ -90,19 +93,23 @@ class Auth:
return await self._async_request_new_token(lwa_params)
@callback
def is_token_valid(self):
def is_token_valid(self) -> bool:
"""Check if a token is already loaded and if it is still valid."""
assert self._prefs is not None
if not self._prefs[STORAGE_ACCESS_TOKEN]:
return False
expire_time = dt_util.parse_datetime(self._prefs[STORAGE_EXPIRE_TIME])
expire_time: datetime | None = dt_util.parse_datetime(
self._prefs[STORAGE_EXPIRE_TIME]
)
assert expire_time is not None
preemptive_expire_time = expire_time - timedelta(
seconds=PREEMPTIVE_REFRESH_TTL_IN_SECONDS
)
return dt_util.utcnow() < preemptive_expire_time
async def _async_request_new_token(self, lwa_params):
async def _async_request_new_token(self, lwa_params: dict[str, str]) -> str | None:
try:
session = aiohttp_client.async_get_clientsession(self.hass)
async with async_timeout.timeout(10):
@ -127,9 +134,9 @@ class Auth:
response_json = await response.json()
_LOGGER.debug("LWA response body : %s", response_json)
access_token = response_json["access_token"]
refresh_token = response_json["refresh_token"]
expires_in = response_json["expires_in"]
access_token: str = response_json["access_token"]
refresh_token: str = response_json["refresh_token"]
expires_in: int = response_json["expires_in"]
expire_time = dt_util.utcnow() + timedelta(seconds=expires_in)
await self._async_update_preferences(
@ -138,7 +145,7 @@ class Auth:
return access_token
async def async_load_preferences(self):
async def async_load_preferences(self) -> None:
"""Load preferences with stored tokens."""
self._prefs = await self._store.async_load()
@ -149,10 +156,13 @@ class Auth:
STORAGE_EXPIRE_TIME: None,
}
async def _async_update_preferences(self, access_token, refresh_token, expire_time):
async def _async_update_preferences(
self, access_token: str, refresh_token: str, expire_time: str
) -> None:
"""Update user preferences."""
if self._prefs is None:
await self.async_load_preferences()
assert self._prefs is not None
if access_token is not None:
self._prefs[STORAGE_ACCESS_TOKEN] = access_token

View File

@ -4,6 +4,9 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
import logging
from typing import Any
from yarl import URL
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers.storage import Store
@ -33,38 +36,38 @@ class AbstractConfig(ABC):
await self._store.async_load()
@property
def supports_auth(self):
def supports_auth(self) -> bool:
"""Return if config supports auth."""
return False
@property
def should_report_state(self):
def should_report_state(self) -> bool:
"""Return if states should be proactively reported."""
return False
@property
def endpoint(self):
@abstractmethod
def endpoint(self) -> str | URL | None:
"""Endpoint for report state."""
return None
@property
@abstractmethod
def locale(self):
def locale(self) -> str | None:
"""Return config locale."""
@property
def entity_config(self):
def entity_config(self) -> dict[str, Any]:
"""Return entity config."""
return {}
@property
def is_reporting_states(self):
def is_reporting_states(self) -> bool:
"""Return if proactive mode is enabled."""
return self._unsub_proactive_report is not None
@callback
@abstractmethod
def user_identifier(self):
def user_identifier(self) -> str:
"""Return an identifier for the user that represents this config."""
async def async_enable_proactive_mode(self) -> None:
@ -85,29 +88,29 @@ class AbstractConfig(ABC):
self._unsub_proactive_report = None
@callback
def should_expose(self, entity_id):
def should_expose(self, entity_id: str) -> bool:
"""If an entity should be exposed."""
return False
@callback
def async_invalidate_access_token(self):
def async_invalidate_access_token(self) -> None:
"""Invalidate access token."""
raise NotImplementedError
async def async_get_access_token(self):
async def async_get_access_token(self) -> str | None:
"""Get an access token."""
raise NotImplementedError
async def async_accept_grant(self, code):
async def async_accept_grant(self, code: str) -> str | None:
"""Accept a grant."""
raise NotImplementedError
@property
def authorized(self):
def authorized(self) -> bool:
"""Return authorization status."""
return self._store.authorized
async def set_authorized(self, authorized) -> None:
async def set_authorized(self, authorized: bool) -> None:
"""Set authorization status.
- Set when an incoming message is received from Alexa.
@ -132,25 +135,26 @@ class AlexaConfigStore:
_STORAGE_VERSION = 1
_STORAGE_KEY = DOMAIN
def __init__(self, hass):
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize a configuration store."""
self._data = None
self._data: dict[str, Any] | None = None
self._hass = hass
self._store = Store(hass, self._STORAGE_VERSION, self._STORAGE_KEY)
self._store: Store = Store(hass, self._STORAGE_VERSION, self._STORAGE_KEY)
@property
def authorized(self):
def authorized(self) -> bool:
"""Return authorization status."""
assert self._data is not None
return self._data[STORE_AUTHORIZED]
@callback
def set_authorized(self, authorized):
def set_authorized(self, authorized: bool) -> None:
"""Set authorization status."""
if authorized != self._data[STORE_AUTHORIZED]:
if self._data is not None and authorized != self._data[STORE_AUTHORIZED]:
self._data[STORE_AUTHORIZED] = authorized
self._store.async_delay_save(lambda: self._data, 1.0)
async def async_load(self):
async def async_load(self) -> None:
"""Load saved configuration from disk."""
if data := await self._store.async_load():
self._data = data

View File

@ -69,7 +69,7 @@ API_TEMP_UNITS = {
# Needs to be ordered dict for `async_api_set_thermostat_mode` which does a
# reverse mapping of this dict and we want to map the first occurrence of OFF
# back to HA state.
API_THERMOSTAT_MODES = OrderedDict(
API_THERMOSTAT_MODES: OrderedDict[str, str] = OrderedDict(
[
(climate.HVACMode.HEAT, "HEAT"),
(climate.HVACMode.COOL, "COOL"),

View File

@ -1,8 +1,9 @@
"""Alexa related errors."""
from __future__ import annotations
from typing import Literal
from typing import Any, Literal
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from .const import API_TEMP_UNITS
@ -29,7 +30,9 @@ class AlexaError(Exception):
namespace: str | None = None
error_type: str | None = None
def __init__(self, error_message, payload=None):
def __init__(
self, error_message: str, payload: dict[str, Any] | None = None
) -> None:
"""Initialize an alexa error."""
Exception.__init__(self)
self.error_message = error_message
@ -42,7 +45,7 @@ class AlexaInvalidEndpointError(AlexaError):
namespace = "Alexa"
error_type = "NO_SUCH_ENDPOINT"
def __init__(self, endpoint_id):
def __init__(self, endpoint_id: str) -> None:
"""Initialize invalid endpoint error."""
msg = f"The endpoint {endpoint_id} does not exist"
AlexaError.__init__(self, msg)
@ -93,7 +96,9 @@ class AlexaTempRangeError(AlexaError):
namespace = "Alexa"
error_type = "TEMPERATURE_VALUE_OUT_OF_RANGE"
def __init__(self, hass, temp, min_temp, max_temp):
def __init__(
self, hass: HomeAssistant, temp: float, min_temp: float, max_temp: float
) -> None:
"""Initialize TempRange error."""
unit = hass.config.units.temperature_unit
temp_range = {

View File

@ -4,10 +4,13 @@ from http import HTTPStatus
import logging
import uuid
from aiohttp.web_response import StreamResponse
from homeassistant.components import http
from homeassistant.const import CONF_PASSWORD
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import template
from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util
from .const import (
@ -32,7 +35,7 @@ FLASH_BRIEFINGS_API_ENDPOINT = "/api/alexa/flash_briefings/{briefing_id}"
@callback
def async_setup(hass, flash_briefing_config):
def async_setup(hass: HomeAssistant, flash_briefing_config: ConfigType) -> None:
"""Activate Alexa component."""
hass.http.register_view(AlexaFlashBriefingView(hass, flash_briefing_config))
@ -44,14 +47,16 @@ class AlexaFlashBriefingView(http.HomeAssistantView):
requires_auth = False
name = "api:alexa:flash_briefings"
def __init__(self, hass, flash_briefings):
def __init__(self, hass: HomeAssistant, flash_briefings: ConfigType) -> None:
"""Initialize Alexa view."""
super().__init__()
self.flash_briefings = flash_briefings
template.attach(hass, self.flash_briefings)
@callback
def get(self, request, briefing_id):
def get(
self, request: http.HomeAssistantRequest, briefing_id: str
) -> StreamResponse | tuple[bytes, HTTPStatus]:
"""Handle Alexa Flash Briefing request."""
_LOGGER.debug("Received Alexa flash briefing request for: %s", briefing_id)

View File

@ -1,20 +1,26 @@
"""Describe logbook events."""
from collections.abc import Callable
from typing import Any
from homeassistant.components.logbook import (
LOGBOOK_ENTRY_ENTITY_ID,
LOGBOOK_ENTRY_MESSAGE,
LOGBOOK_ENTRY_NAME,
)
from homeassistant.core import callback
from homeassistant.core import Event, HomeAssistant, callback
from .const import DOMAIN, EVENT_ALEXA_SMART_HOME
@callback
def async_describe_events(hass, async_describe_event):
def async_describe_events(
hass: HomeAssistant,
async_describe_event: Callable[[str, str, Callable[[Event], dict[str, str]]], None],
) -> None:
"""Describe logbook events."""
@callback
def async_describe_logbook_event(event):
def async_describe_logbook_event(event: Event) -> dict[str, Any]:
"""Describe a logbook event."""
data = event.data

View File

@ -416,6 +416,7 @@ async def async_send_add_or_update_message(
message_serialized = message.serialize()
session = async_get_clientsession(hass)
assert config.endpoint is not None
return await session.post(
config.endpoint, headers=headers, json=message_serialized, allow_redirects=True
)
@ -451,6 +452,7 @@ async def async_send_delete_message(
message_serialized = message.serialize()
session = async_get_clientsession(hass)
assert config.endpoint is not None
return await session.post(
config.endpoint, headers=headers, json=message_serialized, allow_redirects=True
)