Update typing 03 (#48015)

pull/48062/head
Marc Mueller 2021-03-17 21:46:07 +01:00 committed by GitHub
parent 6fb2e63e49
commit fabd73f08b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 417 additions and 379 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from datetime import timedelta from datetime import timedelta
from typing import Any, Dict, List, Optional, Tuple, cast from typing import Any, Dict, Optional, Tuple, cast
import jwt import jwt
@ -36,8 +36,8 @@ class InvalidProvider(Exception):
async def auth_manager_from_config( async def auth_manager_from_config(
hass: HomeAssistant, hass: HomeAssistant,
provider_configs: List[Dict[str, Any]], provider_configs: list[dict[str, Any]],
module_configs: List[Dict[str, Any]], module_configs: list[dict[str, Any]],
) -> AuthManager: ) -> AuthManager:
"""Initialize an auth manager from config. """Initialize an auth manager from config.
@ -87,8 +87,8 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
self, self,
handler_key: Any, handler_key: Any,
*, *,
context: Optional[Dict[str, Any]] = None, context: dict[str, Any] | None = None,
data: Optional[Dict[str, Any]] = None, data: dict[str, Any] | None = None,
) -> data_entry_flow.FlowHandler: ) -> data_entry_flow.FlowHandler:
"""Create a login flow.""" """Create a login flow."""
auth_provider = self.auth_manager.get_auth_provider(*handler_key) auth_provider = self.auth_manager.get_auth_provider(*handler_key)
@ -97,8 +97,8 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
return await auth_provider.async_login_flow(context) return await auth_provider.async_login_flow(context)
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any] self, flow: data_entry_flow.FlowHandler, result: dict[str, Any]
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Return a user as result of login flow.""" """Return a user as result of login flow."""
flow = cast(LoginFlow, flow) flow = cast(LoginFlow, flow)
@ -157,22 +157,22 @@ class AuthManager:
self.login_flow = AuthManagerFlowManager(hass, self) self.login_flow = AuthManagerFlowManager(hass, self)
@property @property
def auth_providers(self) -> List[AuthProvider]: def auth_providers(self) -> list[AuthProvider]:
"""Return a list of available auth providers.""" """Return a list of available auth providers."""
return list(self._providers.values()) return list(self._providers.values())
@property @property
def auth_mfa_modules(self) -> List[MultiFactorAuthModule]: def auth_mfa_modules(self) -> list[MultiFactorAuthModule]:
"""Return a list of available auth modules.""" """Return a list of available auth modules."""
return list(self._mfa_modules.values()) return list(self._mfa_modules.values())
def get_auth_provider( def get_auth_provider(
self, provider_type: str, provider_id: Optional[str] self, provider_type: str, provider_id: str | None
) -> Optional[AuthProvider]: ) -> AuthProvider | None:
"""Return an auth provider, None if not found.""" """Return an auth provider, None if not found."""
return self._providers.get((provider_type, provider_id)) return self._providers.get((provider_type, provider_id))
def get_auth_providers(self, provider_type: str) -> List[AuthProvider]: def get_auth_providers(self, provider_type: str) -> list[AuthProvider]:
"""Return a List of auth provider of one type, Empty if not found.""" """Return a List of auth provider of one type, Empty if not found."""
return [ return [
provider provider
@ -180,30 +180,30 @@ class AuthManager:
if p_type == provider_type if p_type == provider_type
] ]
def get_auth_mfa_module(self, module_id: str) -> Optional[MultiFactorAuthModule]: def get_auth_mfa_module(self, module_id: str) -> MultiFactorAuthModule | None:
"""Return a multi-factor auth module, None if not found.""" """Return a multi-factor auth module, None if not found."""
return self._mfa_modules.get(module_id) return self._mfa_modules.get(module_id)
async def async_get_users(self) -> List[models.User]: async def async_get_users(self) -> list[models.User]:
"""Retrieve all users.""" """Retrieve all users."""
return await self._store.async_get_users() return await self._store.async_get_users()
async def async_get_user(self, user_id: str) -> Optional[models.User]: async def async_get_user(self, user_id: str) -> models.User | None:
"""Retrieve a user.""" """Retrieve a user."""
return await self._store.async_get_user(user_id) return await self._store.async_get_user(user_id)
async def async_get_owner(self) -> Optional[models.User]: async def async_get_owner(self) -> models.User | None:
"""Retrieve the owner.""" """Retrieve the owner."""
users = await self.async_get_users() users = await self.async_get_users()
return next((user for user in users if user.is_owner), None) return next((user for user in users if user.is_owner), None)
async def async_get_group(self, group_id: str) -> Optional[models.Group]: async def async_get_group(self, group_id: str) -> models.Group | None:
"""Retrieve all groups.""" """Retrieve all groups."""
return await self._store.async_get_group(group_id) return await self._store.async_get_group(group_id)
async def async_get_user_by_credentials( async def async_get_user_by_credentials(
self, credentials: models.Credentials self, credentials: models.Credentials
) -> Optional[models.User]: ) -> models.User | None:
"""Get a user by credential, return None if not found.""" """Get a user by credential, return None if not found."""
for user in await self.async_get_users(): for user in await self.async_get_users():
for creds in user.credentials: for creds in user.credentials:
@ -213,7 +213,7 @@ class AuthManager:
return None return None
async def async_create_system_user( async def async_create_system_user(
self, name: str, group_ids: Optional[List[str]] = None self, name: str, group_ids: list[str] | None = None
) -> models.User: ) -> models.User:
"""Create a system user.""" """Create a system user."""
user = await self._store.async_create_user( user = await self._store.async_create_user(
@ -225,10 +225,10 @@ class AuthManager:
return user return user
async def async_create_user( async def async_create_user(
self, name: str, group_ids: Optional[List[str]] = None self, name: str, group_ids: list[str] | None = None
) -> models.User: ) -> models.User:
"""Create a user.""" """Create a user."""
kwargs: Dict[str, Any] = { kwargs: dict[str, Any] = {
"name": name, "name": name,
"is_active": True, "is_active": True,
"group_ids": group_ids or [], "group_ids": group_ids or [],
@ -294,12 +294,12 @@ class AuthManager:
async def async_update_user( async def async_update_user(
self, self,
user: models.User, user: models.User,
name: Optional[str] = None, name: str | None = None,
is_active: Optional[bool] = None, is_active: bool | None = None,
group_ids: Optional[List[str]] = None, group_ids: list[str] | None = None,
) -> None: ) -> None:
"""Update a user.""" """Update a user."""
kwargs: Dict[str, Any] = {} kwargs: dict[str, Any] = {}
if name is not None: if name is not None:
kwargs["name"] = name kwargs["name"] = name
if group_ids is not None: if group_ids is not None:
@ -362,9 +362,9 @@ class AuthManager:
await module.async_depose_user(user.id) await module.async_depose_user(user.id)
async def async_get_enabled_mfa(self, user: models.User) -> Dict[str, str]: async def async_get_enabled_mfa(self, user: models.User) -> dict[str, str]:
"""List enabled mfa modules for user.""" """List enabled mfa modules for user."""
modules: Dict[str, str] = OrderedDict() modules: dict[str, str] = OrderedDict()
for module_id, module in self._mfa_modules.items(): for module_id, module in self._mfa_modules.items():
if await module.async_is_user_setup(user.id): if await module.async_is_user_setup(user.id):
modules[module_id] = module.name modules[module_id] = module.name
@ -373,12 +373,12 @@ class AuthManager:
async def async_create_refresh_token( async def async_create_refresh_token(
self, self,
user: models.User, user: models.User,
client_id: Optional[str] = None, client_id: str | None = None,
client_name: Optional[str] = None, client_name: str | None = None,
client_icon: Optional[str] = None, client_icon: str | None = None,
token_type: Optional[str] = None, token_type: str | None = None,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION, access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
credential: Optional[models.Credentials] = None, credential: models.Credentials | None = None,
) -> models.RefreshToken: ) -> models.RefreshToken:
"""Create a new refresh token for a user.""" """Create a new refresh token for a user."""
if not user.is_active: if not user.is_active:
@ -432,13 +432,13 @@ class AuthManager:
async def async_get_refresh_token( async def async_get_refresh_token(
self, token_id: str self, token_id: str
) -> Optional[models.RefreshToken]: ) -> models.RefreshToken | None:
"""Get refresh token by id.""" """Get refresh token by id."""
return await self._store.async_get_refresh_token(token_id) return await self._store.async_get_refresh_token(token_id)
async def async_get_refresh_token_by_token( async def async_get_refresh_token_by_token(
self, token: str self, token: str
) -> Optional[models.RefreshToken]: ) -> models.RefreshToken | None:
"""Get refresh token by token.""" """Get refresh token by token."""
return await self._store.async_get_refresh_token_by_token(token) return await self._store.async_get_refresh_token_by_token(token)
@ -450,7 +450,7 @@ class AuthManager:
@callback @callback
def async_create_access_token( def async_create_access_token(
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None self, refresh_token: models.RefreshToken, remote_ip: str | None = None
) -> str: ) -> str:
"""Create a new access token.""" """Create a new access token."""
self.async_validate_refresh_token(refresh_token, remote_ip) self.async_validate_refresh_token(refresh_token, remote_ip)
@ -471,7 +471,7 @@ class AuthManager:
@callback @callback
def _async_resolve_provider( def _async_resolve_provider(
self, refresh_token: models.RefreshToken self, refresh_token: models.RefreshToken
) -> Optional[AuthProvider]: ) -> AuthProvider | None:
"""Get the auth provider for the given refresh token. """Get the auth provider for the given refresh token.
Raises an exception if the expected provider is no longer available or return Raises an exception if the expected provider is no longer available or return
@ -492,7 +492,7 @@ class AuthManager:
@callback @callback
def async_validate_refresh_token( def async_validate_refresh_token(
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None self, refresh_token: models.RefreshToken, remote_ip: str | None = None
) -> None: ) -> None:
"""Validate that a refresh token is usable. """Validate that a refresh token is usable.
@ -504,7 +504,7 @@ class AuthManager:
async def async_validate_access_token( async def async_validate_access_token(
self, token: str self, token: str
) -> Optional[models.RefreshToken]: ) -> models.RefreshToken | None:
"""Return refresh token if an access token is valid.""" """Return refresh token if an access token is valid."""
try: try:
unverif_claims = jwt.decode(token, verify=False) unverif_claims = jwt.decode(token, verify=False)
@ -535,7 +535,7 @@ class AuthManager:
@callback @callback
def _async_get_auth_provider( def _async_get_auth_provider(
self, credentials: models.Credentials self, credentials: models.Credentials
) -> Optional[AuthProvider]: ) -> AuthProvider | None:
"""Get auth provider from a set of credentials.""" """Get auth provider from a set of credentials."""
auth_provider_key = ( auth_provider_key = (
credentials.auth_provider_type, credentials.auth_provider_type,

View File

@ -1,10 +1,12 @@
"""Storage for auth models.""" """Storage for auth models."""
from __future__ import annotations
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from datetime import timedelta from datetime import timedelta
import hmac import hmac
from logging import getLogger from logging import getLogger
from typing import Any, Dict, List, Optional from typing import Any
from homeassistant.auth.const import ACCESS_TOKEN_EXPIRATION from homeassistant.auth.const import ACCESS_TOKEN_EXPIRATION
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
@ -34,15 +36,15 @@ class AuthStore:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the auth store.""" """Initialize the auth store."""
self.hass = hass self.hass = hass
self._users: Optional[Dict[str, models.User]] = None self._users: dict[str, models.User] | None = None
self._groups: Optional[Dict[str, models.Group]] = None self._groups: dict[str, models.Group] | None = None
self._perm_lookup: Optional[PermissionLookup] = None self._perm_lookup: PermissionLookup | None = None
self._store = hass.helpers.storage.Store( self._store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True STORAGE_VERSION, STORAGE_KEY, private=True
) )
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
async def async_get_groups(self) -> List[models.Group]: async def async_get_groups(self) -> list[models.Group]:
"""Retrieve all users.""" """Retrieve all users."""
if self._groups is None: if self._groups is None:
await self._async_load() await self._async_load()
@ -50,7 +52,7 @@ class AuthStore:
return list(self._groups.values()) return list(self._groups.values())
async def async_get_group(self, group_id: str) -> Optional[models.Group]: async def async_get_group(self, group_id: str) -> models.Group | None:
"""Retrieve all users.""" """Retrieve all users."""
if self._groups is None: if self._groups is None:
await self._async_load() await self._async_load()
@ -58,7 +60,7 @@ class AuthStore:
return self._groups.get(group_id) return self._groups.get(group_id)
async def async_get_users(self) -> List[models.User]: async def async_get_users(self) -> list[models.User]:
"""Retrieve all users.""" """Retrieve all users."""
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
@ -66,7 +68,7 @@ class AuthStore:
return list(self._users.values()) return list(self._users.values())
async def async_get_user(self, user_id: str) -> Optional[models.User]: async def async_get_user(self, user_id: str) -> models.User | None:
"""Retrieve a user by id.""" """Retrieve a user by id."""
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
@ -76,12 +78,12 @@ class AuthStore:
async def async_create_user( async def async_create_user(
self, self,
name: Optional[str], name: str | None,
is_owner: Optional[bool] = None, is_owner: bool | None = None,
is_active: Optional[bool] = None, is_active: bool | None = None,
system_generated: Optional[bool] = None, system_generated: bool | None = None,
credentials: Optional[models.Credentials] = None, credentials: models.Credentials | None = None,
group_ids: Optional[List[str]] = None, group_ids: list[str] | None = None,
) -> models.User: ) -> models.User:
"""Create a new user.""" """Create a new user."""
if self._users is None: if self._users is None:
@ -97,7 +99,7 @@ class AuthStore:
raise ValueError(f"Invalid group specified {group_id}") raise ValueError(f"Invalid group specified {group_id}")
groups.append(group) groups.append(group)
kwargs: Dict[str, Any] = { kwargs: dict[str, Any] = {
"name": name, "name": name,
# Until we get group management, we just put everyone in the # Until we get group management, we just put everyone in the
# same group. # same group.
@ -146,9 +148,9 @@ class AuthStore:
async def async_update_user( async def async_update_user(
self, self,
user: models.User, user: models.User,
name: Optional[str] = None, name: str | None = None,
is_active: Optional[bool] = None, is_active: bool | None = None,
group_ids: Optional[List[str]] = None, group_ids: list[str] | None = None,
) -> None: ) -> None:
"""Update a user.""" """Update a user."""
assert self._groups is not None assert self._groups is not None
@ -203,15 +205,15 @@ class AuthStore:
async def async_create_refresh_token( async def async_create_refresh_token(
self, self,
user: models.User, user: models.User,
client_id: Optional[str] = None, client_id: str | None = None,
client_name: Optional[str] = None, client_name: str | None = None,
client_icon: Optional[str] = None, client_icon: str | None = None,
token_type: str = models.TOKEN_TYPE_NORMAL, token_type: str = models.TOKEN_TYPE_NORMAL,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION, access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
credential: Optional[models.Credentials] = None, credential: models.Credentials | None = None,
) -> models.RefreshToken: ) -> models.RefreshToken:
"""Create a new token for a user.""" """Create a new token for a user."""
kwargs: Dict[str, Any] = { kwargs: dict[str, Any] = {
"user": user, "user": user,
"client_id": client_id, "client_id": client_id,
"token_type": token_type, "token_type": token_type,
@ -244,7 +246,7 @@ class AuthStore:
async def async_get_refresh_token( async def async_get_refresh_token(
self, token_id: str self, token_id: str
) -> Optional[models.RefreshToken]: ) -> models.RefreshToken | None:
"""Get refresh token by id.""" """Get refresh token by id."""
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
@ -259,7 +261,7 @@ class AuthStore:
async def async_get_refresh_token_by_token( async def async_get_refresh_token_by_token(
self, token: str self, token: str
) -> Optional[models.RefreshToken]: ) -> models.RefreshToken | None:
"""Get refresh token by token.""" """Get refresh token by token."""
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
@ -276,7 +278,7 @@ class AuthStore:
@callback @callback
def async_log_refresh_token_usage( def async_log_refresh_token_usage(
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None self, refresh_token: models.RefreshToken, remote_ip: str | None = None
) -> None: ) -> None:
"""Update refresh token last used information.""" """Update refresh token last used information."""
refresh_token.last_used_at = dt_util.utcnow() refresh_token.last_used_at = dt_util.utcnow()
@ -309,9 +311,9 @@ class AuthStore:
self._set_defaults() self._set_defaults()
return return
users: Dict[str, models.User] = OrderedDict() users: dict[str, models.User] = OrderedDict()
groups: Dict[str, models.Group] = OrderedDict() groups: dict[str, models.Group] = OrderedDict()
credentials: Dict[str, models.Credentials] = OrderedDict() credentials: dict[str, models.Credentials] = OrderedDict()
# Soft-migrating data as we load. We are going to make sure we have a # Soft-migrating data as we load. We are going to make sure we have a
# read only group and an admin group. There are two states that we can # read only group and an admin group. There are two states that we can
@ -328,7 +330,7 @@ class AuthStore:
# was added. # was added.
for group_dict in data.get("groups", []): for group_dict in data.get("groups", []):
policy: Optional[PolicyType] = None policy: PolicyType | None = None
if group_dict["id"] == GROUP_ID_ADMIN: if group_dict["id"] == GROUP_ID_ADMIN:
has_admin_group = True has_admin_group = True
@ -489,7 +491,7 @@ class AuthStore:
self._store.async_delay_save(self._data_to_save, 1) self._store.async_delay_save(self._data_to_save, 1)
@callback @callback
def _data_to_save(self) -> Dict: def _data_to_save(self) -> dict:
"""Return the data to store.""" """Return the data to store."""
assert self._users is not None assert self._users is not None
assert self._groups is not None assert self._groups is not None
@ -508,7 +510,7 @@ class AuthStore:
groups = [] groups = []
for group in self._groups.values(): for group in self._groups.values():
g_dict: Dict[str, Any] = { g_dict: dict[str, Any] = {
"id": group.id, "id": group.id,
# Name not read for sys groups. Kept here for backwards compat # Name not read for sys groups. Kept here for backwards compat
"name": group.name, "name": group.name,
@ -567,7 +569,7 @@ class AuthStore:
"""Set default values for auth store.""" """Set default values for auth store."""
self._users = OrderedDict() self._users = OrderedDict()
groups: Dict[str, models.Group] = OrderedDict() groups: dict[str, models.Group] = OrderedDict()
admin_group = _system_admin_group() admin_group = _system_admin_group()
groups[admin_group.id] = admin_group groups[admin_group.id] = admin_group
user_group = _system_user_group() user_group = _system_user_group()

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import importlib import importlib
import logging import logging
import types import types
from typing import Any, Dict, Optional from typing import Any
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -38,7 +38,7 @@ class MultiFactorAuthModule:
DEFAULT_TITLE = "Unnamed auth module" DEFAULT_TITLE = "Unnamed auth module"
MAX_RETRY_TIME = 3 MAX_RETRY_TIME = 3
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None: def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Initialize an auth module.""" """Initialize an auth module."""
self.hass = hass self.hass = hass
self.config = config self.config = config
@ -87,7 +87,7 @@ class MultiFactorAuthModule:
"""Return whether user is setup.""" """Return whether user is setup."""
raise NotImplementedError raise NotImplementedError
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool: async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
"""Return True if validation passed.""" """Return True if validation passed."""
raise NotImplementedError raise NotImplementedError
@ -104,14 +104,14 @@ class SetupFlow(data_entry_flow.FlowHandler):
self._user_id = user_id self._user_id = user_id
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle the first step of setup flow. """Handle the first step of setup flow.
Return self.async_show_form(step_id='init') if user_input is None. Return self.async_show_form(step_id='init') if user_input is None.
Return self.async_create_entry(data={'result': result}) if finish. Return self.async_create_entry(data={'result': result}) if finish.
""" """
errors: Dict[str, str] = {} errors: dict[str, str] = {}
if user_input: if user_input:
result = await self._auth_module.async_setup_user(self._user_id, user_input) result = await self._auth_module.async_setup_user(self._user_id, user_input)
@ -125,7 +125,7 @@ class SetupFlow(data_entry_flow.FlowHandler):
async def auth_mfa_module_from_config( async def auth_mfa_module_from_config(
hass: HomeAssistant, config: Dict[str, Any] hass: HomeAssistant, config: dict[str, Any]
) -> MultiFactorAuthModule: ) -> MultiFactorAuthModule:
"""Initialize an auth module from a config.""" """Initialize an auth module from a config."""
module_name = config[CONF_TYPE] module_name = config[CONF_TYPE]

View File

@ -1,5 +1,7 @@
"""Example auth module.""" """Example auth module."""
from typing import Any, Dict from __future__ import annotations
from typing import Any
import voluptuous as vol import voluptuous as vol
@ -28,7 +30,7 @@ class InsecureExampleModule(MultiFactorAuthModule):
DEFAULT_TITLE = "Insecure Personal Identify Number" DEFAULT_TITLE = "Insecure Personal Identify Number"
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None: def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Initialize the user data store.""" """Initialize the user data store."""
super().__init__(hass, config) super().__init__(hass, config)
self._data = config["data"] self._data = config["data"]
@ -80,7 +82,7 @@ class InsecureExampleModule(MultiFactorAuthModule):
return True return True
return False return False
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool: async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
"""Return True if validation passed.""" """Return True if validation passed."""
for data in self._data: for data in self._data:
if data["user_id"] == user_id: if data["user_id"] == user_id:

View File

@ -2,10 +2,12 @@
Sending HOTP through notify service Sending HOTP through notify service
""" """
from __future__ import annotations
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict
import attr import attr
import voluptuous as vol import voluptuous as vol
@ -79,8 +81,8 @@ class NotifySetting:
secret: str = attr.ib(factory=_generate_secret) # not persistent secret: str = attr.ib(factory=_generate_secret) # not persistent
counter: int = attr.ib(factory=_generate_random) # not persistent counter: int = attr.ib(factory=_generate_random) # not persistent
notify_service: Optional[str] = attr.ib(default=None) notify_service: str | None = attr.ib(default=None)
target: Optional[str] = attr.ib(default=None) target: str | None = attr.ib(default=None)
_UsersDict = Dict[str, NotifySetting] _UsersDict = Dict[str, NotifySetting]
@ -92,10 +94,10 @@ class NotifyAuthModule(MultiFactorAuthModule):
DEFAULT_TITLE = "Notify One-Time Password" DEFAULT_TITLE = "Notify One-Time Password"
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None: def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Initialize the user data store.""" """Initialize the user data store."""
super().__init__(hass, config) super().__init__(hass, config)
self._user_settings: Optional[_UsersDict] = None self._user_settings: _UsersDict | None = None
self._user_store = hass.helpers.storage.Store( self._user_store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True STORAGE_VERSION, STORAGE_KEY, private=True
) )
@ -146,7 +148,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
) )
@callback @callback
def aync_get_available_notify_services(self) -> List[str]: def aync_get_available_notify_services(self) -> list[str]:
"""Return list of notify services.""" """Return list of notify services."""
unordered_services = set() unordered_services = set()
@ -198,7 +200,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
return user_id in self._user_settings return user_id in self._user_settings
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool: async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
"""Return True if validation passed.""" """Return True if validation passed."""
if self._user_settings is None: if self._user_settings is None:
await self._async_load() await self._async_load()
@ -258,7 +260,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
) )
async def async_notify( async def async_notify(
self, code: str, notify_service: str, target: Optional[str] = None self, code: str, notify_service: str, target: str | None = None
) -> None: ) -> None:
"""Send code by notify service.""" """Send code by notify service."""
data = {"message": self._message_template.format(code)} data = {"message": self._message_template.format(code)}
@ -276,23 +278,23 @@ class NotifySetupFlow(SetupFlow):
auth_module: NotifyAuthModule, auth_module: NotifyAuthModule,
setup_schema: vol.Schema, setup_schema: vol.Schema,
user_id: str, user_id: str,
available_notify_services: List[str], available_notify_services: list[str],
) -> None: ) -> None:
"""Initialize the setup flow.""" """Initialize the setup flow."""
super().__init__(auth_module, setup_schema, user_id) super().__init__(auth_module, setup_schema, user_id)
# to fix typing complaint # to fix typing complaint
self._auth_module: NotifyAuthModule = auth_module self._auth_module: NotifyAuthModule = auth_module
self._available_notify_services = available_notify_services self._available_notify_services = available_notify_services
self._secret: Optional[str] = None self._secret: str | None = None
self._count: Optional[int] = None self._count: int | None = None
self._notify_service: Optional[str] = None self._notify_service: str | None = None
self._target: Optional[str] = None self._target: str | None = None
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Let user select available notify services.""" """Let user select available notify services."""
errors: Dict[str, str] = {} errors: dict[str, str] = {}
hass = self._auth_module.hass hass = self._auth_module.hass
if user_input: if user_input:
@ -306,7 +308,7 @@ class NotifySetupFlow(SetupFlow):
if not self._available_notify_services: if not self._available_notify_services:
return self.async_abort(reason="no_available_service") return self.async_abort(reason="no_available_service")
schema: Dict[str, Any] = OrderedDict() schema: dict[str, Any] = OrderedDict()
schema["notify_service"] = vol.In(self._available_notify_services) schema["notify_service"] = vol.In(self._available_notify_services)
schema["target"] = vol.Optional(str) schema["target"] = vol.Optional(str)
@ -315,10 +317,10 @@ class NotifySetupFlow(SetupFlow):
) )
async def async_step_setup( async def async_step_setup(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Verify user can receive one-time password.""" """Verify user can receive one-time password."""
errors: Dict[str, str] = {} errors: dict[str, str] = {}
hass = self._auth_module.hass hass = self._auth_module.hass
if user_input: if user_input:

View File

@ -1,7 +1,9 @@
"""Time-based One Time Password auth module.""" """Time-based One Time Password auth module."""
from __future__ import annotations
import asyncio import asyncio
from io import BytesIO from io import BytesIO
from typing import Any, Dict, Optional, Tuple from typing import Any
import voluptuous as vol import voluptuous as vol
@ -50,7 +52,7 @@ def _generate_qr_code(data: str) -> str:
) )
def _generate_secret_and_qr_code(username: str) -> Tuple[str, str, str]: def _generate_secret_and_qr_code(username: str) -> tuple[str, str, str]:
"""Generate a secret, url, and QR code.""" """Generate a secret, url, and QR code."""
import pyotp # pylint: disable=import-outside-toplevel import pyotp # pylint: disable=import-outside-toplevel
@ -69,10 +71,10 @@ class TotpAuthModule(MultiFactorAuthModule):
DEFAULT_TITLE = "Time-based One Time Password" DEFAULT_TITLE = "Time-based One Time Password"
MAX_RETRY_TIME = 5 MAX_RETRY_TIME = 5
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None: def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Initialize the user data store.""" """Initialize the user data store."""
super().__init__(hass, config) super().__init__(hass, config)
self._users: Optional[Dict[str, str]] = None self._users: dict[str, str] | None = None
self._user_store = hass.helpers.storage.Store( self._user_store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True STORAGE_VERSION, STORAGE_KEY, private=True
) )
@ -100,7 +102,7 @@ class TotpAuthModule(MultiFactorAuthModule):
"""Save data.""" """Save data."""
await self._user_store.async_save({STORAGE_USERS: self._users}) await self._user_store.async_save({STORAGE_USERS: self._users})
def _add_ota_secret(self, user_id: str, secret: Optional[str] = None) -> str: def _add_ota_secret(self, user_id: str, secret: str | None = None) -> str:
"""Create a ota_secret for user.""" """Create a ota_secret for user."""
import pyotp # pylint: disable=import-outside-toplevel import pyotp # pylint: disable=import-outside-toplevel
@ -145,7 +147,7 @@ class TotpAuthModule(MultiFactorAuthModule):
return user_id in self._users # type: ignore return user_id in self._users # type: ignore
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool: async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
"""Return True if validation passed.""" """Return True if validation passed."""
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
@ -181,13 +183,13 @@ class TotpSetupFlow(SetupFlow):
# to fix typing complaint # to fix typing complaint
self._auth_module: TotpAuthModule = auth_module self._auth_module: TotpAuthModule = auth_module
self._user = user self._user = user
self._ota_secret: Optional[str] = None self._ota_secret: str | None = None
self._url = None # type Optional[str] self._url = None # type Optional[str]
self._image = None # type Optional[str] self._image = None # type Optional[str]
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle the first step of setup flow. """Handle the first step of setup flow.
Return self.async_show_form(step_id='init') if user_input is None. Return self.async_show_form(step_id='init') if user_input is None.
@ -195,7 +197,7 @@ class TotpSetupFlow(SetupFlow):
""" """
import pyotp # pylint: disable=import-outside-toplevel import pyotp # pylint: disable=import-outside-toplevel
errors: Dict[str, str] = {} errors: dict[str, str] = {}
if user_input: if user_input:
verified = await self.hass.async_add_executor_job( verified = await self.hass.async_add_executor_job(

View File

@ -1,7 +1,9 @@
"""Auth models.""" """Auth models."""
from __future__ import annotations
from datetime import datetime, timedelta from datetime import datetime, timedelta
import secrets import secrets
from typing import Dict, List, NamedTuple, Optional from typing import NamedTuple
import uuid import uuid
import attr import attr
@ -21,7 +23,7 @@ TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
class Group: class Group:
"""A group.""" """A group."""
name: Optional[str] = attr.ib() name: str | None = attr.ib()
policy: perm_mdl.PolicyType = attr.ib() policy: perm_mdl.PolicyType = attr.ib()
id: str = attr.ib(factory=lambda: uuid.uuid4().hex) id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
system_generated: bool = attr.ib(default=False) system_generated: bool = attr.ib(default=False)
@ -31,24 +33,24 @@ class Group:
class User: class User:
"""A user.""" """A user."""
name: Optional[str] = attr.ib() name: str | None = attr.ib()
perm_lookup: perm_mdl.PermissionLookup = attr.ib(eq=False, order=False) perm_lookup: perm_mdl.PermissionLookup = attr.ib(eq=False, order=False)
id: str = attr.ib(factory=lambda: uuid.uuid4().hex) id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
is_owner: bool = attr.ib(default=False) is_owner: bool = attr.ib(default=False)
is_active: bool = attr.ib(default=False) is_active: bool = attr.ib(default=False)
system_generated: bool = attr.ib(default=False) system_generated: bool = attr.ib(default=False)
groups: List[Group] = attr.ib(factory=list, eq=False, order=False) groups: list[Group] = attr.ib(factory=list, eq=False, order=False)
# List of credentials of a user. # List of credentials of a user.
credentials: List["Credentials"] = attr.ib(factory=list, eq=False, order=False) credentials: list["Credentials"] = attr.ib(factory=list, eq=False, order=False)
# Tokens associated with a user. # Tokens associated with a user.
refresh_tokens: Dict[str, "RefreshToken"] = attr.ib( refresh_tokens: dict[str, "RefreshToken"] = attr.ib(
factory=dict, eq=False, order=False factory=dict, eq=False, order=False
) )
_permissions: Optional[perm_mdl.PolicyPermissions] = attr.ib( _permissions: perm_mdl.PolicyPermissions | None = attr.ib(
init=False, init=False,
eq=False, eq=False,
order=False, order=False,
@ -89,10 +91,10 @@ class RefreshToken:
"""RefreshToken for a user to grant new access tokens.""" """RefreshToken for a user to grant new access tokens."""
user: User = attr.ib() user: User = attr.ib()
client_id: Optional[str] = attr.ib() client_id: str | None = attr.ib()
access_token_expiration: timedelta = attr.ib() access_token_expiration: timedelta = attr.ib()
client_name: Optional[str] = attr.ib(default=None) client_name: str | None = attr.ib(default=None)
client_icon: Optional[str] = attr.ib(default=None) client_icon: str | None = attr.ib(default=None)
token_type: str = attr.ib( token_type: str = attr.ib(
default=TOKEN_TYPE_NORMAL, default=TOKEN_TYPE_NORMAL,
validator=attr.validators.in_( validator=attr.validators.in_(
@ -104,12 +106,12 @@ class RefreshToken:
token: str = attr.ib(factory=lambda: secrets.token_hex(64)) token: str = attr.ib(factory=lambda: secrets.token_hex(64))
jwt_key: str = attr.ib(factory=lambda: secrets.token_hex(64)) jwt_key: str = attr.ib(factory=lambda: secrets.token_hex(64))
last_used_at: Optional[datetime] = attr.ib(default=None) last_used_at: datetime | None = attr.ib(default=None)
last_used_ip: Optional[str] = attr.ib(default=None) last_used_ip: str | None = attr.ib(default=None)
credential: Optional["Credentials"] = attr.ib(default=None) credential: "Credentials" | None = attr.ib(default=None)
version: Optional[str] = attr.ib(default=__version__) version: str | None = attr.ib(default=__version__)
@attr.s(slots=True) @attr.s(slots=True)
@ -117,7 +119,7 @@ class Credentials:
"""Credentials for a user on an auth provider.""" """Credentials for a user on an auth provider."""
auth_provider_type: str = attr.ib() auth_provider_type: str = attr.ib()
auth_provider_id: Optional[str] = attr.ib() auth_provider_id: str | None = attr.ib()
# Allow the auth provider to store data to represent their auth. # Allow the auth provider to store data to represent their auth.
data: dict = attr.ib() data: dict = attr.ib()
@ -129,5 +131,5 @@ class Credentials:
class UserMeta(NamedTuple): class UserMeta(NamedTuple):
"""User metadata.""" """User metadata."""
name: Optional[str] name: str | None
is_active: bool is_active: bool

View File

@ -1,6 +1,8 @@
"""Permissions for Home Assistant.""" """Permissions for Home Assistant."""
from __future__ import annotations
import logging import logging
from typing import Any, Callable, Optional from typing import Any, Callable
import voluptuous as vol import voluptuous as vol
@ -19,7 +21,7 @@ _LOGGER = logging.getLogger(__name__)
class AbstractPermissions: class AbstractPermissions:
"""Default permissions class.""" """Default permissions class."""
_cached_entity_func: Optional[Callable[[str, str], bool]] = None _cached_entity_func: Callable[[str, str], bool] | None = None
def _entity_func(self) -> Callable[[str, str], bool]: def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access.""" """Return a function that can test entity access."""

View File

@ -1,6 +1,8 @@
"""Entity permissions.""" """Entity permissions."""
from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
from typing import Callable, Optional from typing import Callable
import voluptuous as vol import voluptuous as vol
@ -43,14 +45,14 @@ ENTITY_POLICY_SCHEMA = vol.Any(
def _lookup_domain( def _lookup_domain(
perm_lookup: PermissionLookup, domains_dict: SubCategoryDict, entity_id: str perm_lookup: PermissionLookup, domains_dict: SubCategoryDict, entity_id: str
) -> Optional[ValueType]: ) -> ValueType | None:
"""Look up entity permissions by domain.""" """Look up entity permissions by domain."""
return domains_dict.get(entity_id.split(".", 1)[0]) return domains_dict.get(entity_id.split(".", 1)[0])
def _lookup_area( def _lookup_area(
perm_lookup: PermissionLookup, area_dict: SubCategoryDict, entity_id: str perm_lookup: PermissionLookup, area_dict: SubCategoryDict, entity_id: str
) -> Optional[ValueType]: ) -> ValueType | None:
"""Look up entity permissions by area.""" """Look up entity permissions by area."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id) entity_entry = perm_lookup.entity_registry.async_get(entity_id)
@ -67,7 +69,7 @@ def _lookup_area(
def _lookup_device( def _lookup_device(
perm_lookup: PermissionLookup, devices_dict: SubCategoryDict, entity_id: str perm_lookup: PermissionLookup, devices_dict: SubCategoryDict, entity_id: str
) -> Optional[ValueType]: ) -> ValueType | None:
"""Look up entity permissions by device.""" """Look up entity permissions by device."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id) entity_entry = perm_lookup.entity_registry.async_get(entity_id)
@ -79,7 +81,7 @@ def _lookup_device(
def _lookup_entity_id( def _lookup_entity_id(
perm_lookup: PermissionLookup, entities_dict: SubCategoryDict, entity_id: str perm_lookup: PermissionLookup, entities_dict: SubCategoryDict, entity_id: str
) -> Optional[ValueType]: ) -> ValueType | None:
"""Look up entity permission by entity id.""" """Look up entity permission by entity id."""
return entities_dict.get(entity_id) return entities_dict.get(entity_id)

View File

@ -1,13 +1,15 @@
"""Merging of policies.""" """Merging of policies."""
from typing import Dict, List, Set, cast from __future__ import annotations
from typing import cast
from .types import CategoryType, PolicyType from .types import CategoryType, PolicyType
def merge_policies(policies: List[PolicyType]) -> PolicyType: def merge_policies(policies: list[PolicyType]) -> PolicyType:
"""Merge policies.""" """Merge policies."""
new_policy: Dict[str, CategoryType] = {} new_policy: dict[str, CategoryType] = {}
seen: Set[str] = set() seen: set[str] = set()
for policy in policies: for policy in policies:
for category in policy: for category in policy:
if category in seen: if category in seen:
@ -20,7 +22,7 @@ def merge_policies(policies: List[PolicyType]) -> PolicyType:
return new_policy return new_policy
def _merge_policies(sources: List[CategoryType]) -> CategoryType: def _merge_policies(sources: list[CategoryType]) -> CategoryType:
"""Merge a policy.""" """Merge a policy."""
# When merging policies, the most permissive wins. # When merging policies, the most permissive wins.
# This means we order it like this: # This means we order it like this:
@ -34,7 +36,7 @@ def _merge_policies(sources: List[CategoryType]) -> CategoryType:
# merge each key in the source. # merge each key in the source.
policy: CategoryType = None policy: CategoryType = None
seen: Set[str] = set() seen: set[str] = set()
for source in sources: for source in sources:
if source is None: if source is None:
continue continue

View File

@ -1,6 +1,8 @@
"""Helpers to deal with permissions.""" """Helpers to deal with permissions."""
from __future__ import annotations
from functools import wraps from functools import wraps
from typing import Callable, Dict, List, Optional, cast from typing import Callable, Dict, Optional, cast
from .const import SUBCAT_ALL from .const import SUBCAT_ALL
from .models import PermissionLookup from .models import PermissionLookup
@ -45,7 +47,7 @@ def compile_policy(
assert isinstance(policy, dict) assert isinstance(policy, dict)
funcs: List[Callable[[str, str], Optional[bool]]] = [] funcs: list[Callable[[str, str], bool | None]] = []
for key, lookup_func in subcategories.items(): for key, lookup_func in subcategories.items():
lookup_value = policy.get(key) lookup_value = policy.get(key)
@ -80,10 +82,10 @@ def compile_policy(
def _gen_dict_test_func( def _gen_dict_test_func(
perm_lookup: PermissionLookup, lookup_func: LookupFunc, lookup_dict: SubCategoryDict perm_lookup: PermissionLookup, lookup_func: LookupFunc, lookup_dict: SubCategoryDict
) -> Callable[[str, str], Optional[bool]]: ) -> Callable[[str, str], bool | None]:
"""Generate a lookup function.""" """Generate a lookup function."""
def test_value(object_id: str, key: str) -> Optional[bool]: def test_value(object_id: str, key: str) -> bool | None:
"""Test if permission is allowed based on the keys.""" """Test if permission is allowed based on the keys."""
schema: ValueType = lookup_func(perm_lookup, lookup_dict, object_id) schema: ValueType = lookup_func(perm_lookup, lookup_dict, object_id)

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import importlib import importlib
import logging import logging
import types import types
from typing import Any, Dict, List, Optional from typing import Any
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -42,7 +42,7 @@ class AuthProvider:
DEFAULT_TITLE = "Unnamed auth provider" DEFAULT_TITLE = "Unnamed auth provider"
def __init__( def __init__(
self, hass: HomeAssistant, store: AuthStore, config: Dict[str, Any] self, hass: HomeAssistant, store: AuthStore, config: dict[str, Any]
) -> None: ) -> None:
"""Initialize an auth provider.""" """Initialize an auth provider."""
self.hass = hass self.hass = hass
@ -50,7 +50,7 @@ class AuthProvider:
self.config = config self.config = config
@property @property
def id(self) -> Optional[str]: def id(self) -> str | None:
"""Return id of the auth provider. """Return id of the auth provider.
Optional, can be None. Optional, can be None.
@ -72,7 +72,7 @@ class AuthProvider:
"""Return whether multi-factor auth supported by the auth provider.""" """Return whether multi-factor auth supported by the auth provider."""
return True return True
async def async_credentials(self) -> List[Credentials]: async def async_credentials(self) -> list[Credentials]:
"""Return all credentials of this provider.""" """Return all credentials of this provider."""
users = await self.store.async_get_users() users = await self.store.async_get_users()
return [ return [
@ -86,7 +86,7 @@ class AuthProvider:
] ]
@callback @callback
def async_create_credentials(self, data: Dict[str, str]) -> Credentials: def async_create_credentials(self, data: dict[str, str]) -> Credentials:
"""Create credentials.""" """Create credentials."""
return Credentials( return Credentials(
auth_provider_type=self.type, auth_provider_id=self.id, data=data auth_provider_type=self.type, auth_provider_id=self.id, data=data
@ -94,7 +94,7 @@ class AuthProvider:
# Implement by extending class # Implement by extending class
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow: async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return the data flow for logging in with auth provider. """Return the data flow for logging in with auth provider.
Auth provider should extend LoginFlow and return an instance. Auth provider should extend LoginFlow and return an instance.
@ -102,7 +102,7 @@ class AuthProvider:
raise NotImplementedError raise NotImplementedError
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str] self, flow_result: dict[str, str]
) -> Credentials: ) -> Credentials:
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
raise NotImplementedError raise NotImplementedError
@ -121,7 +121,7 @@ class AuthProvider:
@callback @callback
def async_validate_refresh_token( def async_validate_refresh_token(
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None self, refresh_token: RefreshToken, remote_ip: str | None = None
) -> None: ) -> None:
"""Verify a refresh token is still valid. """Verify a refresh token is still valid.
@ -131,7 +131,7 @@ class AuthProvider:
async def auth_provider_from_config( async def auth_provider_from_config(
hass: HomeAssistant, store: AuthStore, config: Dict[str, Any] hass: HomeAssistant, store: AuthStore, config: dict[str, Any]
) -> AuthProvider: ) -> AuthProvider:
"""Initialize an auth provider from a config.""" """Initialize an auth provider from a config."""
provider_name = config[CONF_TYPE] provider_name = config[CONF_TYPE]
@ -188,17 +188,17 @@ class LoginFlow(data_entry_flow.FlowHandler):
def __init__(self, auth_provider: AuthProvider) -> None: def __init__(self, auth_provider: AuthProvider) -> None:
"""Initialize the login flow.""" """Initialize the login flow."""
self._auth_provider = auth_provider self._auth_provider = auth_provider
self._auth_module_id: Optional[str] = None self._auth_module_id: str | None = None
self._auth_manager = auth_provider.hass.auth self._auth_manager = auth_provider.hass.auth
self.available_mfa_modules: Dict[str, str] = {} self.available_mfa_modules: dict[str, str] = {}
self.created_at = dt_util.utcnow() self.created_at = dt_util.utcnow()
self.invalid_mfa_times = 0 self.invalid_mfa_times = 0
self.user: Optional[User] = None self.user: User | None = None
self.credential: Optional[Credentials] = None self.credential: Credentials | None = None
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle the first step of login flow. """Handle the first step of login flow.
Return self.async_show_form(step_id='init') if user_input is None. Return self.async_show_form(step_id='init') if user_input is None.
@ -207,8 +207,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
raise NotImplementedError raise NotImplementedError
async def async_step_select_mfa_module( async def async_step_select_mfa_module(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle the step of select mfa module.""" """Handle the step of select mfa module."""
errors = {} errors = {}
@ -232,8 +232,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
) )
async def async_step_mfa( async def async_step_mfa(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle the step of mfa validation.""" """Handle the step of mfa validation."""
assert self.credential assert self.credential
assert self.user assert self.user
@ -273,7 +273,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
if not errors: if not errors:
return await self.async_finish(self.credential) return await self.async_finish(self.credential)
description_placeholders: Dict[str, Optional[str]] = { description_placeholders: dict[str, str | None] = {
"mfa_module_name": auth_module.name, "mfa_module_name": auth_module.name,
"mfa_module_id": auth_module.id, "mfa_module_id": auth_module.id,
} }
@ -285,6 +285,6 @@ class LoginFlow(data_entry_flow.FlowHandler):
errors=errors, errors=errors,
) )
async def async_finish(self, flow_result: Any) -> Dict: async def async_finish(self, flow_result: Any) -> dict:
"""Handle the pass of login flow.""" """Handle the pass of login flow."""
return self.async_create_entry(title=self._auth_provider.name, data=flow_result) return self.async_create_entry(title=self._auth_provider.name, data=flow_result)

View File

@ -1,10 +1,11 @@
"""Auth provider that validates credentials via an external command.""" """Auth provider that validates credentials via an external command."""
from __future__ import annotations
import asyncio.subprocess import asyncio.subprocess
import collections import collections
import logging import logging
import os import os
from typing import Any, Dict, Optional, cast from typing import Any, cast
import voluptuous as vol import voluptuous as vol
@ -51,9 +52,9 @@ class CommandLineAuthProvider(AuthProvider):
attributes provided by external programs. attributes provided by external programs.
""" """
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._user_meta: Dict[str, Dict[str, Any]] = {} self._user_meta: dict[str, dict[str, Any]] = {}
async def async_login_flow(self, context: Optional[dict]) -> LoginFlow: async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return CommandLineLoginFlow(self) return CommandLineLoginFlow(self)
@ -82,7 +83,7 @@ class CommandLineAuthProvider(AuthProvider):
raise InvalidAuthError raise InvalidAuthError
if self.config[CONF_META]: if self.config[CONF_META]:
meta: Dict[str, str] = {} meta: dict[str, str] = {}
for _line in stdout.splitlines(): for _line in stdout.splitlines():
try: try:
line = _line.decode().lstrip() line = _line.decode().lstrip()
@ -99,7 +100,7 @@ class CommandLineAuthProvider(AuthProvider):
self._user_meta[username] = meta self._user_meta[username] = meta
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str] self, flow_result: dict[str, str]
) -> Credentials: ) -> Credentials:
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
username = flow_result["username"] username = flow_result["username"]
@ -125,8 +126,8 @@ class CommandLineLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}
@ -143,7 +144,7 @@ class CommandLineLoginFlow(LoginFlow):
user_input.pop("password") user_input.pop("password")
return await self.async_finish(user_input) return await self.async_finish(user_input)
schema: Dict[str, type] = collections.OrderedDict() schema: dict[str, type] = collections.OrderedDict()
schema["username"] = str schema["username"] = str
schema["password"] = str schema["password"] = str

View File

@ -5,7 +5,7 @@ import asyncio
import base64 import base64
from collections import OrderedDict from collections import OrderedDict
import logging import logging
from typing import Any, Dict, List, Optional, Set, cast from typing import Any, cast
import bcrypt import bcrypt
import voluptuous as vol import voluptuous as vol
@ -21,7 +21,7 @@ STORAGE_VERSION = 1
STORAGE_KEY = "auth_provider.homeassistant" STORAGE_KEY = "auth_provider.homeassistant"
def _disallow_id(conf: Dict[str, Any]) -> Dict[str, Any]: def _disallow_id(conf: dict[str, Any]) -> dict[str, Any]:
"""Disallow ID in config.""" """Disallow ID in config."""
if CONF_ID in conf: if CONF_ID in conf:
raise vol.Invalid("ID is not allowed for the homeassistant auth provider.") raise vol.Invalid("ID is not allowed for the homeassistant auth provider.")
@ -62,7 +62,7 @@ class Data:
self._store = hass.helpers.storage.Store( self._store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True STORAGE_VERSION, STORAGE_KEY, private=True
) )
self._data: Optional[Dict[str, Any]] = None self._data: dict[str, Any] | None = None
# Legacy mode will allow usernames to start/end with whitespace # Legacy mode will allow usernames to start/end with whitespace
# and will compare usernames case-insensitive. # and will compare usernames case-insensitive.
# Remove in 2020 or when we launch 1.0. # Remove in 2020 or when we launch 1.0.
@ -83,7 +83,7 @@ class Data:
if data is None: if data is None:
data = {"users": []} data = {"users": []}
seen: Set[str] = set() seen: set[str] = set()
for user in data["users"]: for user in data["users"]:
username = user["username"] username = user["username"]
@ -121,7 +121,7 @@ class Data:
self._data = data self._data = data
@property @property
def users(self) -> List[Dict[str, str]]: def users(self) -> list[dict[str, str]]:
"""Return users.""" """Return users."""
return self._data["users"] # type: ignore return self._data["users"] # type: ignore
@ -220,7 +220,7 @@ class HassAuthProvider(AuthProvider):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize an Home Assistant auth provider.""" """Initialize an Home Assistant auth provider."""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.data: Optional[Data] = None self.data: Data | None = None
self._init_lock = asyncio.Lock() self._init_lock = asyncio.Lock()
async def async_initialize(self) -> None: async def async_initialize(self) -> None:
@ -233,7 +233,7 @@ class HassAuthProvider(AuthProvider):
await data.async_load() await data.async_load()
self.data = data self.data = data
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow: async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return HassLoginFlow(self) return HassLoginFlow(self)
@ -277,7 +277,7 @@ class HassAuthProvider(AuthProvider):
await self.data.async_save() await self.data.async_save()
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str] self, flow_result: dict[str, str]
) -> Credentials: ) -> Credentials:
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
if self.data is None: if self.data is None:
@ -318,8 +318,8 @@ class HassLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}
@ -335,7 +335,7 @@ class HassLoginFlow(LoginFlow):
user_input.pop("password") user_input.pop("password")
return await self.async_finish(user_input) return await self.async_finish(user_input)
schema: Dict[str, type] = OrderedDict() schema: dict[str, type] = OrderedDict()
schema["username"] = str schema["username"] = str
schema["password"] = str schema["password"] = str

View File

@ -1,7 +1,9 @@
"""Example auth provider.""" """Example auth provider."""
from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
import hmac import hmac
from typing import Any, Dict, Optional, cast from typing import Any, cast
import voluptuous as vol import voluptuous as vol
@ -33,7 +35,7 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider): class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords.""" """Example auth provider based on hardcoded usernames and passwords."""
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow: async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return ExampleLoginFlow(self) return ExampleLoginFlow(self)
@ -60,7 +62,7 @@ class ExampleAuthProvider(AuthProvider):
raise InvalidAuthError raise InvalidAuthError
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str] self, flow_result: dict[str, str]
) -> Credentials: ) -> Credentials:
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
username = flow_result["username"] username = flow_result["username"]
@ -94,8 +96,8 @@ class ExampleLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}
@ -111,7 +113,7 @@ class ExampleLoginFlow(LoginFlow):
user_input.pop("password") user_input.pop("password")
return await self.async_finish(user_input) return await self.async_finish(user_input)
schema: Dict[str, type] = OrderedDict() schema: dict[str, type] = OrderedDict()
schema["username"] = str schema["username"] = str
schema["password"] = str schema["password"] = str

View File

@ -3,8 +3,10 @@ Support Legacy API password auth provider.
It will be removed when auth system production ready It will be removed when auth system production ready
""" """
from __future__ import annotations
import hmac import hmac
from typing import Any, Dict, Optional, cast from typing import Any, cast
import voluptuous as vol import voluptuous as vol
@ -40,7 +42,7 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
"""Return api_password.""" """Return api_password."""
return str(self.config[CONF_API_PASSWORD]) return str(self.config[CONF_API_PASSWORD])
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow: async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return LegacyLoginFlow(self) return LegacyLoginFlow(self)
@ -55,7 +57,7 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
raise InvalidAuthError raise InvalidAuthError
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str] self, flow_result: dict[str, str]
) -> Credentials: ) -> Credentials:
"""Return credentials for this login.""" """Return credentials for this login."""
credentials = await self.async_credentials() credentials = await self.async_credentials()
@ -79,8 +81,8 @@ class LegacyLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}

View File

@ -3,6 +3,8 @@
It shows list of users if access from trusted network. It shows list of users if access from trusted network.
Abort login flow if not access from trusted network. Abort login flow if not access from trusted network.
""" """
from __future__ import annotations
from ipaddress import ( from ipaddress import (
IPv4Address, IPv4Address,
IPv4Network, IPv4Network,
@ -11,7 +13,7 @@ from ipaddress import (
ip_address, ip_address,
ip_network, ip_network,
) )
from typing import Any, Dict, List, Optional, Union, cast from typing import Any, Dict, List, Union, cast
import voluptuous as vol import voluptuous as vol
@ -68,12 +70,12 @@ class TrustedNetworksAuthProvider(AuthProvider):
DEFAULT_TITLE = "Trusted Networks" DEFAULT_TITLE = "Trusted Networks"
@property @property
def trusted_networks(self) -> List[IPNetwork]: def trusted_networks(self) -> list[IPNetwork]:
"""Return trusted networks.""" """Return trusted networks."""
return cast(List[IPNetwork], self.config[CONF_TRUSTED_NETWORKS]) return cast(List[IPNetwork], self.config[CONF_TRUSTED_NETWORKS])
@property @property
def trusted_users(self) -> Dict[IPNetwork, Any]: def trusted_users(self) -> dict[IPNetwork, Any]:
"""Return trusted users per network.""" """Return trusted users per network."""
return cast(Dict[IPNetwork, Any], self.config[CONF_TRUSTED_USERS]) return cast(Dict[IPNetwork, Any], self.config[CONF_TRUSTED_USERS])
@ -82,7 +84,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
"""Trusted Networks auth provider does not support MFA.""" """Trusted Networks auth provider does not support MFA."""
return False return False
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow: async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
assert context is not None assert context is not None
ip_addr = cast(IPAddress, context.get("ip_address")) ip_addr = cast(IPAddress, context.get("ip_address"))
@ -125,7 +127,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
) )
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str] self, flow_result: dict[str, str]
) -> Credentials: ) -> Credentials:
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
user_id = flow_result["user"] user_id = flow_result["user"]
@ -169,7 +171,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
@callback @callback
def async_validate_refresh_token( def async_validate_refresh_token(
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None self, refresh_token: RefreshToken, remote_ip: str | None = None
) -> None: ) -> None:
"""Verify a refresh token is still valid.""" """Verify a refresh token is still valid."""
if remote_ip is None: if remote_ip is None:
@ -186,7 +188,7 @@ class TrustedNetworksLoginFlow(LoginFlow):
self, self,
auth_provider: TrustedNetworksAuthProvider, auth_provider: TrustedNetworksAuthProvider,
ip_addr: IPAddress, ip_addr: IPAddress,
available_users: Dict[str, Optional[str]], available_users: dict[str, str | None],
allow_bypass_login: bool, allow_bypass_login: bool,
) -> None: ) -> None:
"""Initialize the login flow.""" """Initialize the login flow."""
@ -196,8 +198,8 @@ class TrustedNetworksLoginFlow(LoginFlow):
self._allow_bypass_login = allow_bypass_login self._allow_bypass_login = allow_bypass_login
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None self, user_input: dict[str, str] | None = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Handle the step of the form.""" """Handle the step of the form."""
try: try:
cast( cast(

View File

@ -1,11 +1,13 @@
"""Home Assistant command line scripts.""" """Home Assistant command line scripts."""
from __future__ import annotations
import argparse import argparse
import asyncio import asyncio
import importlib import importlib
import logging import logging
import os import os
import sys import sys
from typing import List, Optional, Sequence, Text from typing import Sequence
from homeassistant import runner from homeassistant import runner
from homeassistant.bootstrap import async_mount_local_lib_path from homeassistant.bootstrap import async_mount_local_lib_path
@ -16,7 +18,7 @@ from homeassistant.util.package import install_package, is_installed, is_virtual
# mypy: allow-untyped-defs, no-warn-return-any # mypy: allow-untyped-defs, no-warn-return-any
def run(args: List) -> int: def run(args: list) -> int:
"""Run a script.""" """Run a script."""
scripts = [] scripts = []
path = os.path.dirname(__file__) path = os.path.dirname(__file__)
@ -65,7 +67,7 @@ def run(args: List) -> int:
return script.run(args[1:]) # type: ignore return script.run(args[1:]) # type: ignore
def extract_config_dir(args: Optional[Sequence[Text]] = None) -> str: def extract_config_dir(args: Sequence[str] | None = None) -> str:
"""Extract the config dir from the arguments or get the default.""" """Extract the config dir from the arguments or get the default."""
parser = argparse.ArgumentParser(add_help=False) parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("-c", "--config", default=None) parser.add_argument("-c", "--config", default=None)

View File

@ -1,4 +1,6 @@
"""Script to run benchmarks.""" """Script to run benchmarks."""
from __future__ import annotations
import argparse import argparse
import asyncio import asyncio
import collections import collections
@ -7,7 +9,7 @@ from datetime import datetime
import json import json
import logging import logging
from timeit import default_timer as timer from timeit import default_timer as timer
from typing import Callable, Dict, TypeVar from typing import Callable, TypeVar
from homeassistant import core from homeassistant import core
from homeassistant.components.websocket_api.const import JSON_DUMP from homeassistant.components.websocket_api.const import JSON_DUMP
@ -21,7 +23,7 @@ from homeassistant.util import dt as dt_util
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name
BENCHMARKS: Dict[str, Callable] = {} BENCHMARKS: dict[str, Callable] = {}
def run(args): def run(args):

View File

@ -1,4 +1,6 @@
"""Script to check the configuration file.""" """Script to check the configuration file."""
from __future__ import annotations
import argparse import argparse
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
@ -6,7 +8,7 @@ from collections.abc import Mapping, Sequence
from glob import glob from glob import glob
import logging import logging
import os import os
from typing import Any, Callable, Dict, List, Tuple from typing import Any, Callable
from unittest.mock import patch from unittest.mock import patch
from homeassistant import core from homeassistant import core
@ -22,13 +24,13 @@ REQUIREMENTS = ("colorlog==4.7.2",)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# pylint: disable=protected-access # pylint: disable=protected-access
MOCKS: Dict[str, Tuple[str, Callable]] = { MOCKS: dict[str, tuple[str, Callable]] = {
"load": ("homeassistant.util.yaml.loader.load_yaml", yaml_loader.load_yaml), "load": ("homeassistant.util.yaml.loader.load_yaml", yaml_loader.load_yaml),
"load*": ("homeassistant.config.load_yaml", yaml_loader.load_yaml), "load*": ("homeassistant.config.load_yaml", yaml_loader.load_yaml),
"secrets": ("homeassistant.util.yaml.loader.secret_yaml", yaml_loader.secret_yaml), "secrets": ("homeassistant.util.yaml.loader.secret_yaml", yaml_loader.secret_yaml),
} }
PATCHES: Dict[str, Any] = {} PATCHES: dict[str, Any] = {}
C_HEAD = "bold" C_HEAD = "bold"
ERROR_STR = "General Errors" ERROR_STR = "General Errors"
@ -48,7 +50,7 @@ def color(the_color, *args, reset=None):
raise ValueError(f"Invalid color {k!s} in {the_color}") from k raise ValueError(f"Invalid color {k!s} in {the_color}") from k
def run(script_args: List) -> int: def run(script_args: list) -> int:
"""Handle check config commandline script.""" """Handle check config commandline script."""
parser = argparse.ArgumentParser(description="Check Home Assistant configuration.") parser = argparse.ArgumentParser(description="Check Home Assistant configuration.")
parser.add_argument("--script", choices=["check_config"]) parser.add_argument("--script", choices=["check_config"])
@ -83,7 +85,7 @@ def run(script_args: List) -> int:
res = check(config_dir, args.secrets) res = check(config_dir, args.secrets)
domain_info: List[str] = [] domain_info: list[str] = []
if args.info: if args.info:
domain_info = args.info.split(",") domain_info = args.info.split(",")
@ -123,7 +125,7 @@ def run(script_args: List) -> int:
dump_dict(res["components"].get(domain)) dump_dict(res["components"].get(domain))
if args.secrets: if args.secrets:
flatsecret: Dict[str, str] = {} flatsecret: dict[str, str] = {}
for sfn, sdict in res["secret_cache"].items(): for sfn, sdict in res["secret_cache"].items():
sss = [] sss = []
@ -149,7 +151,7 @@ def run(script_args: List) -> int:
def check(config_dir, secrets=False): def check(config_dir, secrets=False):
"""Perform a check by mocking hass load functions.""" """Perform a check by mocking hass load functions."""
logging.getLogger("homeassistant.loader").setLevel(logging.CRITICAL) logging.getLogger("homeassistant.loader").setLevel(logging.CRITICAL)
res: Dict[str, Any] = { res: dict[str, Any] = {
"yaml_files": OrderedDict(), # yaml_files loaded "yaml_files": OrderedDict(), # yaml_files loaded
"secrets": OrderedDict(), # secret cache and secrets loaded "secrets": OrderedDict(), # secret cache and secrets loaded
"except": OrderedDict(), # exceptions raised (with config) "except": OrderedDict(), # exceptions raised (with config)

View File

@ -1,4 +1,6 @@
"""Helper methods for various modules.""" """Helper methods for various modules."""
from __future__ import annotations
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
import enum import enum
@ -9,16 +11,7 @@ import socket
import string import string
import threading import threading
from types import MappingProxyType from types import MappingProxyType
from typing import ( from typing import Any, Callable, Coroutine, Iterable, KeysView, TypeVar
Any,
Callable,
Coroutine,
Iterable,
KeysView,
Optional,
TypeVar,
Union,
)
import slugify as unicode_slug import slugify as unicode_slug
@ -106,8 +99,8 @@ def repr_helper(inp: Any) -> str:
def convert( def convert(
value: Optional[T], to_type: Callable[[T], U], default: Optional[U] = None value: T | None, to_type: Callable[[T], U], default: U | None = None
) -> Optional[U]: ) -> U | None:
"""Convert value to to_type, returns default if fails.""" """Convert value to to_type, returns default if fails."""
try: try:
return default if value is None else to_type(value) return default if value is None else to_type(value)
@ -117,7 +110,7 @@ def convert(
def ensure_unique_string( def ensure_unique_string(
preferred_string: str, current_strings: Union[Iterable[str], KeysView[str]] preferred_string: str, current_strings: Iterable[str] | KeysView[str]
) -> str: ) -> str:
"""Return a string that is not present in current_strings. """Return a string that is not present in current_strings.
@ -213,7 +206,7 @@ class Throttle:
""" """
def __init__( def __init__(
self, min_time: timedelta, limit_no_throttle: Optional[timedelta] = None self, min_time: timedelta, limit_no_throttle: timedelta | None = None
) -> None: ) -> None:
"""Initialize the throttle.""" """Initialize the throttle."""
self.min_time = min_time self.min_time = min_time
@ -253,7 +246,7 @@ class Throttle:
) )
@wraps(method) @wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]: def wrapper(*args: Any, **kwargs: Any) -> Callable | Coroutine:
"""Wrap that allows wrapped to be called only once per min_time. """Wrap that allows wrapped to be called only once per min_time.
If we cannot acquire the lock, it is running so return None. If we cannot acquire the lock, it is running so return None.

View File

@ -1,7 +1,9 @@
"""Utilities to help with aiohttp.""" """Utilities to help with aiohttp."""
from __future__ import annotations
import io import io
import json import json
from typing import Any, Dict, Optional from typing import Any
from urllib.parse import parse_qsl from urllib.parse import parse_qsl
from multidict import CIMultiDict, MultiDict from multidict import CIMultiDict, MultiDict
@ -26,7 +28,7 @@ class MockStreamReader:
class MockRequest: class MockRequest:
"""Mock an aiohttp request.""" """Mock an aiohttp request."""
mock_source: Optional[str] = None mock_source: str | None = None
def __init__( def __init__(
self, self,
@ -34,8 +36,8 @@ class MockRequest:
mock_source: str, mock_source: str,
method: str = "GET", method: str = "GET",
status: int = HTTP_OK, status: int = HTTP_OK,
headers: Optional[Dict[str, str]] = None, headers: dict[str, str] | None = None,
query_string: Optional[str] = None, query_string: str | None = None,
url: str = "", url: str = "",
) -> None: ) -> None:
"""Initialize a request.""" """Initialize a request."""

View File

@ -1,7 +1,8 @@
"""Color util methods.""" """Color util methods."""
from __future__ import annotations
import colorsys import colorsys
import math import math
from typing import List, Optional, Tuple
import attr import attr
@ -183,7 +184,7 @@ class GamutType:
blue: XYPoint = attr.ib() blue: XYPoint = attr.ib()
def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]: def color_name_to_rgb(color_name: str) -> tuple[int, int, int]:
"""Convert color name to RGB hex value.""" """Convert color name to RGB hex value."""
# COLORS map has no spaces in it, so make the color_name have no # COLORS map has no spaces in it, so make the color_name have no
# spaces in it as well for matching purposes # spaces in it as well for matching purposes
@ -198,8 +199,8 @@ def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]:
def color_RGB_to_xy( def color_RGB_to_xy(
iR: int, iG: int, iB: int, Gamut: Optional[GamutType] = None iR: int, iG: int, iB: int, Gamut: GamutType | None = None
) -> Tuple[float, float]: ) -> tuple[float, float]:
"""Convert from RGB color to XY color.""" """Convert from RGB color to XY color."""
return color_RGB_to_xy_brightness(iR, iG, iB, Gamut)[:2] return color_RGB_to_xy_brightness(iR, iG, iB, Gamut)[:2]
@ -208,8 +209,8 @@ def color_RGB_to_xy(
# http://www.developers.meethue.com/documentation/color-conversions-rgb-xy # http://www.developers.meethue.com/documentation/color-conversions-rgb-xy
# License: Code is given as is. Use at your own risk and discretion. # License: Code is given as is. Use at your own risk and discretion.
def color_RGB_to_xy_brightness( def color_RGB_to_xy_brightness(
iR: int, iG: int, iB: int, Gamut: Optional[GamutType] = None iR: int, iG: int, iB: int, Gamut: GamutType | None = None
) -> Tuple[float, float, int]: ) -> tuple[float, float, int]:
"""Convert from RGB color to XY color.""" """Convert from RGB color to XY color."""
if iR + iG + iB == 0: if iR + iG + iB == 0:
return 0.0, 0.0, 0 return 0.0, 0.0, 0
@ -248,8 +249,8 @@ def color_RGB_to_xy_brightness(
def color_xy_to_RGB( def color_xy_to_RGB(
vX: float, vY: float, Gamut: Optional[GamutType] = None vX: float, vY: float, Gamut: GamutType | None = None
) -> Tuple[int, int, int]: ) -> tuple[int, int, int]:
"""Convert from XY to a normalized RGB.""" """Convert from XY to a normalized RGB."""
return color_xy_brightness_to_RGB(vX, vY, 255, Gamut) return color_xy_brightness_to_RGB(vX, vY, 255, Gamut)
@ -257,8 +258,8 @@ def color_xy_to_RGB(
# Converted to Python from Obj-C, original source from: # Converted to Python from Obj-C, original source from:
# http://www.developers.meethue.com/documentation/color-conversions-rgb-xy # http://www.developers.meethue.com/documentation/color-conversions-rgb-xy
def color_xy_brightness_to_RGB( def color_xy_brightness_to_RGB(
vX: float, vY: float, ibrightness: int, Gamut: Optional[GamutType] = None vX: float, vY: float, ibrightness: int, Gamut: GamutType | None = None
) -> Tuple[int, int, int]: ) -> tuple[int, int, int]:
"""Convert from XYZ to RGB.""" """Convert from XYZ to RGB."""
if Gamut: if Gamut:
if not check_point_in_lamps_reach((vX, vY), Gamut): if not check_point_in_lamps_reach((vX, vY), Gamut):
@ -304,7 +305,7 @@ def color_xy_brightness_to_RGB(
return (ir, ig, ib) return (ir, ig, ib)
def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]: def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> tuple[int, int, int]:
"""Convert a hsb into its rgb representation.""" """Convert a hsb into its rgb representation."""
if fS == 0.0: if fS == 0.0:
fV = int(fB * 255) fV = int(fB * 255)
@ -345,7 +346,7 @@ def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]:
return (r, g, b) return (r, g, b)
def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> Tuple[float, float, float]: def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> tuple[float, float, float]:
"""Convert an rgb color to its hsv representation. """Convert an rgb color to its hsv representation.
Hue is scaled 0-360 Hue is scaled 0-360
@ -356,12 +357,12 @@ def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> Tuple[float, float, flo
return round(fHSV[0] * 360, 3), round(fHSV[1] * 100, 3), round(fHSV[2] * 100, 3) return round(fHSV[0] * 360, 3), round(fHSV[1] * 100, 3), round(fHSV[2] * 100, 3)
def color_RGB_to_hs(iR: float, iG: float, iB: float) -> Tuple[float, float]: def color_RGB_to_hs(iR: float, iG: float, iB: float) -> tuple[float, float]:
"""Convert an rgb color to its hs representation.""" """Convert an rgb color to its hs representation."""
return color_RGB_to_hsv(iR, iG, iB)[:2] return color_RGB_to_hsv(iR, iG, iB)[:2]
def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> Tuple[int, int, int]: def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> tuple[int, int, int]:
"""Convert an hsv color into its rgb representation. """Convert an hsv color into its rgb representation.
Hue is scaled 0-360 Hue is scaled 0-360
@ -372,27 +373,27 @@ def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> Tuple[int, int, int]:
return (int(fRGB[0] * 255), int(fRGB[1] * 255), int(fRGB[2] * 255)) return (int(fRGB[0] * 255), int(fRGB[1] * 255), int(fRGB[2] * 255))
def color_hs_to_RGB(iH: float, iS: float) -> Tuple[int, int, int]: def color_hs_to_RGB(iH: float, iS: float) -> tuple[int, int, int]:
"""Convert an hsv color into its rgb representation.""" """Convert an hsv color into its rgb representation."""
return color_hsv_to_RGB(iH, iS, 100) return color_hsv_to_RGB(iH, iS, 100)
def color_xy_to_hs( def color_xy_to_hs(
vX: float, vY: float, Gamut: Optional[GamutType] = None vX: float, vY: float, Gamut: GamutType | None = None
) -> Tuple[float, float]: ) -> tuple[float, float]:
"""Convert an xy color to its hs representation.""" """Convert an xy color to its hs representation."""
h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY, Gamut)) h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY, Gamut))
return h, s return h, s
def color_hs_to_xy( def color_hs_to_xy(
iH: float, iS: float, Gamut: Optional[GamutType] = None iH: float, iS: float, Gamut: GamutType | None = None
) -> Tuple[float, float]: ) -> tuple[float, float]:
"""Convert an hs color to its xy representation.""" """Convert an hs color to its xy representation."""
return color_RGB_to_xy(*color_hs_to_RGB(iH, iS), Gamut) return color_RGB_to_xy(*color_hs_to_RGB(iH, iS), Gamut)
def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple: def _match_max_scale(input_colors: tuple, output_colors: tuple) -> tuple:
"""Match the maximum value of the output to the input.""" """Match the maximum value of the output to the input."""
max_in = max(input_colors) max_in = max(input_colors)
max_out = max(output_colors) max_out = max(output_colors)
@ -403,7 +404,7 @@ def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple:
return tuple(int(round(i * factor)) for i in output_colors) return tuple(int(round(i * factor)) for i in output_colors)
def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]: def color_rgb_to_rgbw(r: int, g: int, b: int) -> tuple[int, int, int, int]:
"""Convert an rgb color to an rgbw representation.""" """Convert an rgb color to an rgbw representation."""
# Calculate the white channel as the minimum of input rgb channels. # Calculate the white channel as the minimum of input rgb channels.
# Subtract the white portion from the remaining rgb channels. # Subtract the white portion from the remaining rgb channels.
@ -415,7 +416,7 @@ def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]:
return _match_max_scale((r, g, b), rgbw) # type: ignore return _match_max_scale((r, g, b), rgbw) # type: ignore
def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> Tuple[int, int, int]: def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> tuple[int, int, int]:
"""Convert an rgbw color to an rgb representation.""" """Convert an rgbw color to an rgb representation."""
# Add the white channel back into the rgb channels. # Add the white channel back into the rgb channels.
rgb = (r + w, g + w, b + w) rgb = (r + w, g + w, b + w)
@ -430,7 +431,7 @@ def color_rgb_to_hex(r: int, g: int, b: int) -> str:
return "{:02x}{:02x}{:02x}".format(round(r), round(g), round(b)) return "{:02x}{:02x}{:02x}".format(round(r), round(g), round(b))
def rgb_hex_to_rgb_list(hex_string: str) -> List[int]: def rgb_hex_to_rgb_list(hex_string: str) -> list[int]:
"""Return an RGB color value list from a hex color string.""" """Return an RGB color value list from a hex color string."""
return [ return [
int(hex_string[i : i + len(hex_string) // 3], 16) int(hex_string[i : i + len(hex_string) // 3], 16)
@ -438,14 +439,14 @@ def rgb_hex_to_rgb_list(hex_string: str) -> List[int]:
] ]
def color_temperature_to_hs(color_temperature_kelvin: float) -> Tuple[float, float]: def color_temperature_to_hs(color_temperature_kelvin: float) -> tuple[float, float]:
"""Return an hs color from a color temperature in Kelvin.""" """Return an hs color from a color temperature in Kelvin."""
return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin)) return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin))
def color_temperature_to_rgb( def color_temperature_to_rgb(
color_temperature_kelvin: float, color_temperature_kelvin: float,
) -> Tuple[float, float, float]: ) -> tuple[float, float, float]:
""" """
Return an RGB color from a color temperature in Kelvin. Return an RGB color from a color temperature in Kelvin.
@ -555,8 +556,8 @@ def get_closest_point_to_line(A: XYPoint, B: XYPoint, P: XYPoint) -> XYPoint:
def get_closest_point_to_point( def get_closest_point_to_point(
xy_tuple: Tuple[float, float], Gamut: GamutType xy_tuple: tuple[float, float], Gamut: GamutType
) -> Tuple[float, float]: ) -> tuple[float, float]:
""" """
Get the closest matching color within the gamut of the light. Get the closest matching color within the gamut of the light.
@ -592,7 +593,7 @@ def get_closest_point_to_point(
return (cx, cy) return (cx, cy)
def check_point_in_lamps_reach(p: Tuple[float, float], Gamut: GamutType) -> bool: def check_point_in_lamps_reach(p: tuple[float, float], Gamut: GamutType) -> bool:
"""Check if the provided XYPoint can be recreated by a Hue lamp.""" """Check if the provided XYPoint can be recreated by a Hue lamp."""
v1 = XYPoint(Gamut.green.x - Gamut.red.x, Gamut.green.y - Gamut.red.y) v1 = XYPoint(Gamut.green.x - Gamut.red.x, Gamut.green.y - Gamut.red.y)
v2 = XYPoint(Gamut.blue.x - Gamut.red.x, Gamut.blue.y - Gamut.red.y) v2 = XYPoint(Gamut.blue.x - Gamut.red.x, Gamut.blue.y - Gamut.red.y)

View File

@ -1,6 +1,8 @@
"""Distance util functions.""" """Distance util functions."""
from __future__ import annotations
from numbers import Number from numbers import Number
from typing import Callable, Dict from typing import Callable
from homeassistant.const import ( from homeassistant.const import (
LENGTH, LENGTH,
@ -26,7 +28,7 @@ VALID_UNITS = [
LENGTH_YARD, LENGTH_YARD,
] ]
TO_METERS: Dict[str, Callable[[float], float]] = { TO_METERS: dict[str, Callable[[float], float]] = {
LENGTH_METERS: lambda meters: meters, LENGTH_METERS: lambda meters: meters,
LENGTH_MILES: lambda miles: miles * 1609.344, LENGTH_MILES: lambda miles: miles * 1609.344,
LENGTH_YARD: lambda yards: yards * 0.9144, LENGTH_YARD: lambda yards: yards * 0.9144,
@ -37,7 +39,7 @@ TO_METERS: Dict[str, Callable[[float], float]] = {
LENGTH_MILLIMETERS: lambda millimeters: millimeters * 0.001, LENGTH_MILLIMETERS: lambda millimeters: millimeters * 0.001,
} }
METERS_TO: Dict[str, Callable[[float], float]] = { METERS_TO: dict[str, Callable[[float], float]] = {
LENGTH_METERS: lambda meters: meters, LENGTH_METERS: lambda meters: meters,
LENGTH_MILES: lambda meters: meters * 0.000621371, LENGTH_MILES: lambda meters: meters * 0.000621371,
LENGTH_YARD: lambda meters: meters * 1.09361, LENGTH_YARD: lambda meters: meters * 1.09361,

View File

@ -1,7 +1,9 @@
"""Helper methods to handle the time in Home Assistant.""" """Helper methods to handle the time in Home Assistant."""
from __future__ import annotations
import datetime as dt import datetime as dt
import re import re
from typing import Any, Dict, List, Optional, Union, cast from typing import Any, cast
import ciso8601 import ciso8601
import pytz import pytz
@ -40,7 +42,7 @@ def set_default_time_zone(time_zone: dt.tzinfo) -> None:
DEFAULT_TIME_ZONE = time_zone DEFAULT_TIME_ZONE = time_zone
def get_time_zone(time_zone_str: str) -> Optional[dt.tzinfo]: def get_time_zone(time_zone_str: str) -> dt.tzinfo | None:
"""Get time zone from string. Return None if unable to determine. """Get time zone from string. Return None if unable to determine.
Async friendly. Async friendly.
@ -56,7 +58,7 @@ def utcnow() -> dt.datetime:
return dt.datetime.now(NATIVE_UTC) return dt.datetime.now(NATIVE_UTC)
def now(time_zone: Optional[dt.tzinfo] = None) -> dt.datetime: def now(time_zone: dt.tzinfo | None = None) -> dt.datetime:
"""Get now in specified time zone.""" """Get now in specified time zone."""
return dt.datetime.now(time_zone or DEFAULT_TIME_ZONE) return dt.datetime.now(time_zone or DEFAULT_TIME_ZONE)
@ -77,7 +79,7 @@ def as_utc(dattim: dt.datetime) -> dt.datetime:
def as_timestamp(dt_value: dt.datetime) -> float: def as_timestamp(dt_value: dt.datetime) -> float:
"""Convert a date/time into a unix time (seconds since 1970).""" """Convert a date/time into a unix time (seconds since 1970)."""
if hasattr(dt_value, "timestamp"): if hasattr(dt_value, "timestamp"):
parsed_dt: Optional[dt.datetime] = dt_value parsed_dt: dt.datetime | None = dt_value
else: else:
parsed_dt = parse_datetime(str(dt_value)) parsed_dt = parse_datetime(str(dt_value))
if parsed_dt is None: if parsed_dt is None:
@ -100,9 +102,7 @@ def utc_from_timestamp(timestamp: float) -> dt.datetime:
return UTC.localize(dt.datetime.utcfromtimestamp(timestamp)) return UTC.localize(dt.datetime.utcfromtimestamp(timestamp))
def start_of_local_day( def start_of_local_day(dt_or_d: dt.date | dt.datetime | None = None) -> dt.datetime:
dt_or_d: Union[dt.date, dt.datetime, None] = None
) -> dt.datetime:
"""Return local datetime object of start of day from date or datetime.""" """Return local datetime object of start of day from date or datetime."""
if dt_or_d is None: if dt_or_d is None:
date: dt.date = now().date() date: dt.date = now().date()
@ -119,7 +119,7 @@ def start_of_local_day(
# Copyright (c) Django Software Foundation and individual contributors. # Copyright (c) Django Software Foundation and individual contributors.
# All rights reserved. # All rights reserved.
# https://github.com/django/django/blob/master/LICENSE # https://github.com/django/django/blob/master/LICENSE
def parse_datetime(dt_str: str) -> Optional[dt.datetime]: def parse_datetime(dt_str: str) -> dt.datetime | None:
"""Parse a string and return a datetime.datetime. """Parse a string and return a datetime.datetime.
This function supports time zone offsets. When the input contains one, This function supports time zone offsets. When the input contains one,
@ -134,12 +134,12 @@ def parse_datetime(dt_str: str) -> Optional[dt.datetime]:
match = DATETIME_RE.match(dt_str) match = DATETIME_RE.match(dt_str)
if not match: if not match:
return None return None
kws: Dict[str, Any] = match.groupdict() kws: dict[str, Any] = match.groupdict()
if kws["microsecond"]: if kws["microsecond"]:
kws["microsecond"] = kws["microsecond"].ljust(6, "0") kws["microsecond"] = kws["microsecond"].ljust(6, "0")
tzinfo_str = kws.pop("tzinfo") tzinfo_str = kws.pop("tzinfo")
tzinfo: Optional[dt.tzinfo] = None tzinfo: dt.tzinfo | None = None
if tzinfo_str == "Z": if tzinfo_str == "Z":
tzinfo = UTC tzinfo = UTC
elif tzinfo_str is not None: elif tzinfo_str is not None:
@ -154,7 +154,7 @@ def parse_datetime(dt_str: str) -> Optional[dt.datetime]:
return dt.datetime(**kws) return dt.datetime(**kws)
def parse_date(dt_str: str) -> Optional[dt.date]: def parse_date(dt_str: str) -> dt.date | None:
"""Convert a date string to a date object.""" """Convert a date string to a date object."""
try: try:
return dt.datetime.strptime(dt_str, DATE_STR_FORMAT).date() return dt.datetime.strptime(dt_str, DATE_STR_FORMAT).date()
@ -162,7 +162,7 @@ def parse_date(dt_str: str) -> Optional[dt.date]:
return None return None
def parse_time(time_str: str) -> Optional[dt.time]: def parse_time(time_str: str) -> dt.time | None:
"""Parse a time string (00:20:00) into Time object. """Parse a time string (00:20:00) into Time object.
Return None if invalid. Return None if invalid.
@ -213,7 +213,7 @@ def get_age(date: dt.datetime) -> str:
return formatn(rounded_delta, selected_unit) return formatn(rounded_delta, selected_unit)
def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> List[int]: def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> list[int]:
"""Parse the time expression part and return a list of times to match.""" """Parse the time expression part and return a list of times to match."""
if parameter is None or parameter == MATCH_ALL: if parameter is None or parameter == MATCH_ALL:
res = list(range(min_value, max_value + 1)) res = list(range(min_value, max_value + 1))
@ -241,9 +241,9 @@ def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> Lis
def find_next_time_expression_time( def find_next_time_expression_time(
now: dt.datetime, # pylint: disable=redefined-outer-name now: dt.datetime, # pylint: disable=redefined-outer-name
seconds: List[int], seconds: list[int],
minutes: List[int], minutes: list[int],
hours: List[int], hours: list[int],
) -> dt.datetime: ) -> dt.datetime:
"""Find the next datetime from now for which the time expression matches. """Find the next datetime from now for which the time expression matches.
@ -257,7 +257,7 @@ def find_next_time_expression_time(
if not seconds or not minutes or not hours: if not seconds or not minutes or not hours:
raise ValueError("Cannot find a next time: Time expression never matches!") raise ValueError("Cannot find a next time: Time expression never matches!")
def _lower_bound(arr: List[int], cmp: int) -> Optional[int]: def _lower_bound(arr: list[int], cmp: int) -> int | None:
"""Return the first value in arr greater or equal to cmp. """Return the first value in arr greater or equal to cmp.
Return None if no such value exists. Return None if no such value exists.

View File

@ -1,10 +1,12 @@
"""JSON utility functions.""" """JSON utility functions."""
from __future__ import annotations
from collections import deque from collections import deque
import json import json
import logging import logging
import os import os
import tempfile import tempfile
from typing import Any, Callable, Dict, List, Optional, Type, Union from typing import Any, Callable
from homeassistant.core import Event, State from homeassistant.core import Event, State
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@ -20,9 +22,7 @@ class WriteError(HomeAssistantError):
"""Error writing the data.""" """Error writing the data."""
def load_json( def load_json(filename: str, default: list | dict | None = None) -> list | dict:
filename: str, default: Union[List, Dict, None] = None
) -> Union[List, Dict]:
"""Load JSON data from a file and return as dict or list. """Load JSON data from a file and return as dict or list.
Defaults to returning empty dict if file is not found. Defaults to returning empty dict if file is not found.
@ -44,10 +44,10 @@ def load_json(
def save_json( def save_json(
filename: str, filename: str,
data: Union[List, Dict], data: list | dict,
private: bool = False, private: bool = False,
*, *,
encoder: Optional[Type[json.JSONEncoder]] = None, encoder: type[json.JSONEncoder] | None = None,
) -> None: ) -> None:
"""Save JSON data to a file. """Save JSON data to a file.
@ -85,7 +85,7 @@ def save_json(
_LOGGER.error("JSON replacement cleanup failed: %s", err) _LOGGER.error("JSON replacement cleanup failed: %s", err)
def format_unserializable_data(data: Dict[str, Any]) -> str: def format_unserializable_data(data: dict[str, Any]) -> str:
"""Format output of find_paths in a friendly way. """Format output of find_paths in a friendly way.
Format is comma separated: <path>=<value>(<type>) Format is comma separated: <path>=<value>(<type>)
@ -95,7 +95,7 @@ def format_unserializable_data(data: Dict[str, Any]) -> str:
def find_paths_unserializable_data( def find_paths_unserializable_data(
bad_data: Any, *, dump: Callable[[Any], str] = json.dumps bad_data: Any, *, dump: Callable[[Any], str] = json.dumps
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Find the paths to unserializable data. """Find the paths to unserializable data.
This method is slow! Only use for error handling. This method is slow! Only use for error handling.

View File

@ -3,10 +3,12 @@ Module with location helpers.
detect_location_info and elevation are mocked by default during tests. detect_location_info and elevation are mocked by default during tests.
""" """
from __future__ import annotations
import asyncio import asyncio
import collections import collections
import math import math
from typing import Any, Dict, Optional, Tuple from typing import Any
import aiohttp import aiohttp
@ -47,7 +49,7 @@ LocationInfo = collections.namedtuple(
async def async_detect_location_info( async def async_detect_location_info(
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
) -> Optional[LocationInfo]: ) -> LocationInfo | None:
"""Detect location information.""" """Detect location information."""
data = await _get_ipapi(session) data = await _get_ipapi(session)
@ -63,8 +65,8 @@ async def async_detect_location_info(
def distance( def distance(
lat1: Optional[float], lon1: Optional[float], lat2: float, lon2: float lat1: float | None, lon1: float | None, lat2: float, lon2: float
) -> Optional[float]: ) -> float | None:
"""Calculate the distance in meters between two points. """Calculate the distance in meters between two points.
Async friendly. Async friendly.
@ -81,8 +83,8 @@ def distance(
# Source: https://github.com/maurycyp/vincenty # Source: https://github.com/maurycyp/vincenty
# License: https://github.com/maurycyp/vincenty/blob/master/LICENSE # License: https://github.com/maurycyp/vincenty/blob/master/LICENSE
def vincenty( def vincenty(
point1: Tuple[float, float], point2: Tuple[float, float], miles: bool = False point1: tuple[float, float], point2: tuple[float, float], miles: bool = False
) -> Optional[float]: ) -> float | None:
""" """
Vincenty formula (inverse method) to calculate the distance. Vincenty formula (inverse method) to calculate the distance.
@ -162,7 +164,7 @@ def vincenty(
return round(s, 6) return round(s, 6)
async def _get_ipapi(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]: async def _get_ipapi(session: aiohttp.ClientSession) -> dict[str, Any] | None:
"""Query ipapi.co for location data.""" """Query ipapi.co for location data."""
try: try:
resp = await session.get(IPAPI, timeout=5) resp = await session.get(IPAPI, timeout=5)
@ -192,7 +194,7 @@ async def _get_ipapi(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]
} }
async def _get_ip_api(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]: async def _get_ip_api(session: aiohttp.ClientSession) -> dict[str, Any] | None:
"""Query ip-api.com for location data.""" """Query ip-api.com for location data."""
try: try:
resp = await session.get(IP_API, timeout=5) resp = await session.get(IP_API, timeout=5)

View File

@ -1,4 +1,6 @@
"""Logging utilities.""" """Logging utilities."""
from __future__ import annotations
import asyncio import asyncio
from functools import partial, wraps from functools import partial, wraps
import inspect import inspect
@ -6,7 +8,7 @@ import logging
import logging.handlers import logging.handlers
import queue import queue
import traceback import traceback
from typing import Any, Awaitable, Callable, Coroutine, Union, cast, overload from typing import Any, Awaitable, Callable, Coroutine, cast, overload
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
@ -115,7 +117,7 @@ def catch_log_exception(
def catch_log_exception( def catch_log_exception(
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
) -> Union[Callable[..., None], Callable[..., Awaitable[None]]]: ) -> Callable[..., None] | Callable[..., Awaitable[None]]:
"""Decorate a callback to catch and log exceptions.""" """Decorate a callback to catch and log exceptions."""
# Check for partials to properly determine if coroutine function # Check for partials to properly determine if coroutine function
@ -123,7 +125,7 @@ def catch_log_exception(
while isinstance(check_func, partial): while isinstance(check_func, partial):
check_func = check_func.func check_func = check_func.func
wrapper_func: Union[Callable[..., None], Callable[..., Awaitable[None]]] wrapper_func: Callable[..., None] | Callable[..., Awaitable[None]]
if asyncio.iscoroutinefunction(check_func): if asyncio.iscoroutinefunction(check_func):
async_func = cast(Callable[..., Awaitable[None]], func) async_func = cast(Callable[..., Awaitable[None]], func)

View File

@ -1,6 +1,7 @@
"""Network utilities.""" """Network utilities."""
from __future__ import annotations
from ipaddress import IPv4Address, IPv6Address, ip_address, ip_network from ipaddress import IPv4Address, IPv6Address, ip_address, ip_network
from typing import Union
import yarl import yarl
@ -23,22 +24,22 @@ PRIVATE_NETWORKS = (
LINK_LOCAL_NETWORK = ip_network("169.254.0.0/16") LINK_LOCAL_NETWORK = ip_network("169.254.0.0/16")
def is_loopback(address: Union[IPv4Address, IPv6Address]) -> bool: def is_loopback(address: IPv4Address | IPv6Address) -> bool:
"""Check if an address is a loopback address.""" """Check if an address is a loopback address."""
return any(address in network for network in LOOPBACK_NETWORKS) return any(address in network for network in LOOPBACK_NETWORKS)
def is_private(address: Union[IPv4Address, IPv6Address]) -> bool: def is_private(address: IPv4Address | IPv6Address) -> bool:
"""Check if an address is a private address.""" """Check if an address is a private address."""
return any(address in network for network in PRIVATE_NETWORKS) return any(address in network for network in PRIVATE_NETWORKS)
def is_link_local(address: Union[IPv4Address, IPv6Address]) -> bool: def is_link_local(address: IPv4Address | IPv6Address) -> bool:
"""Check if an address is link local.""" """Check if an address is link local."""
return address in LINK_LOCAL_NETWORK return address in LINK_LOCAL_NETWORK
def is_local(address: Union[IPv4Address, IPv6Address]) -> bool: def is_local(address: IPv4Address | IPv6Address) -> bool:
"""Check if an address is loopback or private.""" """Check if an address is loopback or private."""
return is_loopback(address) or is_private(address) return is_loopback(address) or is_private(address)

View File

@ -1,4 +1,6 @@
"""Helpers to install PyPi packages.""" """Helpers to install PyPi packages."""
from __future__ import annotations
import asyncio import asyncio
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
import logging import logging
@ -6,7 +8,6 @@ import os
from pathlib import Path from pathlib import Path
from subprocess import PIPE, Popen from subprocess import PIPE, Popen
import sys import sys
from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import pkg_resources import pkg_resources
@ -59,10 +60,10 @@ def is_installed(package: str) -> bool:
def install_package( def install_package(
package: str, package: str,
upgrade: bool = True, upgrade: bool = True,
target: Optional[str] = None, target: str | None = None,
constraints: Optional[str] = None, constraints: str | None = None,
find_links: Optional[str] = None, find_links: str | None = None,
no_cache_dir: Optional[bool] = False, no_cache_dir: bool | None = False,
) -> bool: ) -> bool:
"""Install a package on PyPi. Accepts pip compatible package strings. """Install a package on PyPi. Accepts pip compatible package strings.

View File

@ -1,9 +1,8 @@
"""Percentage util functions.""" """Percentage util functions."""
from __future__ import annotations
from typing import List, Tuple
def ordered_list_item_to_percentage(ordered_list: List[str], item: str) -> int: def ordered_list_item_to_percentage(ordered_list: list[str], item: str) -> int:
"""Determine the percentage of an item in an ordered list. """Determine the percentage of an item in an ordered list.
When using this utility for fan speeds, do not include "off" When using this utility for fan speeds, do not include "off"
@ -26,7 +25,7 @@ def ordered_list_item_to_percentage(ordered_list: List[str], item: str) -> int:
return (list_position * 100) // list_len return (list_position * 100) // list_len
def percentage_to_ordered_list_item(ordered_list: List[str], percentage: int) -> str: def percentage_to_ordered_list_item(ordered_list: list[str], percentage: int) -> str:
"""Find the item that most closely matches the percentage in an ordered list. """Find the item that most closely matches the percentage in an ordered list.
When using this utility for fan speeds, do not include "off" When using this utility for fan speeds, do not include "off"
@ -54,7 +53,7 @@ def percentage_to_ordered_list_item(ordered_list: List[str], percentage: int) ->
def ranged_value_to_percentage( def ranged_value_to_percentage(
low_high_range: Tuple[float, float], value: float low_high_range: tuple[float, float], value: float
) -> int: ) -> int:
"""Given a range of low and high values convert a single value to a percentage. """Given a range of low and high values convert a single value to a percentage.
@ -71,7 +70,7 @@ def ranged_value_to_percentage(
def percentage_to_ranged_value( def percentage_to_ranged_value(
low_high_range: Tuple[float, float], percentage: int low_high_range: tuple[float, float], percentage: int
) -> float: ) -> float:
"""Given a range of low and high values convert a percentage to a single value. """Given a range of low and high values convert a percentage to a single value.
@ -87,11 +86,11 @@ def percentage_to_ranged_value(
return states_in_range(low_high_range) * percentage / 100 return states_in_range(low_high_range) * percentage / 100
def states_in_range(low_high_range: Tuple[float, float]) -> float: def states_in_range(low_high_range: tuple[float, float]) -> float:
"""Given a range of low and high values return how many states exist.""" """Given a range of low and high values return how many states exist."""
return low_high_range[1] - low_high_range[0] + 1 return low_high_range[1] - low_high_range[0] + 1
def int_states_in_range(low_high_range: Tuple[float, float]) -> int: def int_states_in_range(low_high_range: tuple[float, float]) -> int:
"""Given a range of low and high values return how many integer states exist.""" """Given a range of low and high values return how many integer states exist."""
return int(states_in_range(low_high_range)) return int(states_in_range(low_high_range))

View File

@ -2,18 +2,18 @@
Can only be used by integrations that have pillow in their requirements. Can only be used by integrations that have pillow in their requirements.
""" """
from typing import Tuple from __future__ import annotations
from PIL import ImageDraw from PIL import ImageDraw
def draw_box( def draw_box(
draw: ImageDraw, draw: ImageDraw,
box: Tuple[float, float, float, float], box: tuple[float, float, float, float],
img_width: int, img_width: int,
img_height: int, img_height: int,
text: str = "", text: str = "",
color: Tuple[int, int, int] = (255, 255, 0), color: tuple[int, int, int] = (255, 255, 0),
) -> None: ) -> None:
""" """
Draw a bounding box on and image. Draw a bounding box on and image.

View File

@ -1,9 +1,11 @@
"""ruamel.yaml utility functions.""" """ruamel.yaml utility functions."""
from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
import logging import logging
import os import os
from os import O_CREAT, O_TRUNC, O_WRONLY, stat_result from os import O_CREAT, O_TRUNC, O_WRONLY, stat_result
from typing import Dict, List, Optional, Union from typing import Dict, List, Union
import ruamel.yaml import ruamel.yaml
from ruamel.yaml import YAML # type: ignore from ruamel.yaml import YAML # type: ignore
@ -22,7 +24,7 @@ JSON_TYPE = Union[List, Dict, str] # pylint: disable=invalid-name
class ExtSafeConstructor(SafeConstructor): class ExtSafeConstructor(SafeConstructor):
"""Extended SafeConstructor.""" """Extended SafeConstructor."""
name: Optional[str] = None name: str | None = None
class UnsupportedYamlError(HomeAssistantError): class UnsupportedYamlError(HomeAssistantError):
@ -77,7 +79,7 @@ def yaml_to_object(data: str) -> JSON_TYPE:
"""Create object from yaml string.""" """Create object from yaml string."""
yaml = YAML(typ="rt") yaml = YAML(typ="rt")
try: try:
result: Union[List, Dict, str] = yaml.load(data) result: list | dict | str = yaml.load(data)
return result return result
except YAMLError as exc: except YAMLError as exc:
_LOGGER.error("YAML error: %s", exc) _LOGGER.error("YAML error: %s", exc)

View File

@ -8,7 +8,7 @@ from __future__ import annotations
import asyncio import asyncio
import enum import enum
from types import TracebackType from types import TracebackType
from typing import Any, Dict, List, Optional, Type, Union from typing import Any
from .async_ import run_callback_threadsafe from .async_ import run_callback_threadsafe
@ -38,10 +38,10 @@ class _GlobalFreezeContext:
async def __aexit__( async def __aexit__(
self, self,
exc_type: Type[BaseException], exc_type: type[BaseException],
exc_val: BaseException, exc_val: BaseException,
exc_tb: TracebackType, exc_tb: TracebackType,
) -> Optional[bool]: ) -> bool | None:
self._exit() self._exit()
return None return None
@ -51,10 +51,10 @@ class _GlobalFreezeContext:
def __exit__( # pylint: disable=useless-return def __exit__( # pylint: disable=useless-return
self, self,
exc_type: Type[BaseException], exc_type: type[BaseException],
exc_val: BaseException, exc_val: BaseException,
exc_tb: TracebackType, exc_tb: TracebackType,
) -> Optional[bool]: ) -> bool | None:
self._loop.call_soon_threadsafe(self._exit) self._loop.call_soon_threadsafe(self._exit)
return None return None
@ -106,10 +106,10 @@ class _ZoneFreezeContext:
async def __aexit__( async def __aexit__(
self, self,
exc_type: Type[BaseException], exc_type: type[BaseException],
exc_val: BaseException, exc_val: BaseException,
exc_tb: TracebackType, exc_tb: TracebackType,
) -> Optional[bool]: ) -> bool | None:
self._exit() self._exit()
return None return None
@ -119,10 +119,10 @@ class _ZoneFreezeContext:
def __exit__( # pylint: disable=useless-return def __exit__( # pylint: disable=useless-return
self, self,
exc_type: Type[BaseException], exc_type: type[BaseException],
exc_val: BaseException, exc_val: BaseException,
exc_tb: TracebackType, exc_tb: TracebackType,
) -> Optional[bool]: ) -> bool | None:
self._loop.call_soon_threadsafe(self._exit) self._loop.call_soon_threadsafe(self._exit)
return None return None
@ -155,8 +155,8 @@ class _GlobalTaskContext:
self._manager: TimeoutManager = manager self._manager: TimeoutManager = manager
self._task: asyncio.Task[Any] = task self._task: asyncio.Task[Any] = task
self._time_left: float = timeout self._time_left: float = timeout
self._expiration_time: Optional[float] = None self._expiration_time: float | None = None
self._timeout_handler: Optional[asyncio.Handle] = None self._timeout_handler: asyncio.Handle | None = None
self._wait_zone: asyncio.Event = asyncio.Event() self._wait_zone: asyncio.Event = asyncio.Event()
self._state: _State = _State.INIT self._state: _State = _State.INIT
self._cool_down: float = cool_down self._cool_down: float = cool_down
@ -169,10 +169,10 @@ class _GlobalTaskContext:
async def __aexit__( async def __aexit__(
self, self,
exc_type: Type[BaseException], exc_type: type[BaseException],
exc_val: BaseException, exc_val: BaseException,
exc_tb: TracebackType, exc_tb: TracebackType,
) -> Optional[bool]: ) -> bool | None:
self._stop_timer() self._stop_timer()
self._manager.global_tasks.remove(self) self._manager.global_tasks.remove(self)
@ -263,8 +263,8 @@ class _ZoneTaskContext:
self._task: asyncio.Task[Any] = task self._task: asyncio.Task[Any] = task
self._state: _State = _State.INIT self._state: _State = _State.INIT
self._time_left: float = timeout self._time_left: float = timeout
self._expiration_time: Optional[float] = None self._expiration_time: float | None = None
self._timeout_handler: Optional[asyncio.Handle] = None self._timeout_handler: asyncio.Handle | None = None
@property @property
def state(self) -> _State: def state(self) -> _State:
@ -283,10 +283,10 @@ class _ZoneTaskContext:
async def __aexit__( async def __aexit__(
self, self,
exc_type: Type[BaseException], exc_type: type[BaseException],
exc_val: BaseException, exc_val: BaseException,
exc_tb: TracebackType, exc_tb: TracebackType,
) -> Optional[bool]: ) -> bool | None:
self._zone.exit_task(self) self._zone.exit_task(self)
self._stop_timer() self._stop_timer()
@ -344,8 +344,8 @@ class _ZoneTimeoutManager:
"""Initialize internal timeout context manager.""" """Initialize internal timeout context manager."""
self._manager: TimeoutManager = manager self._manager: TimeoutManager = manager
self._zone: str = zone self._zone: str = zone
self._tasks: List[_ZoneTaskContext] = [] self._tasks: list[_ZoneTaskContext] = []
self._freezes: List[_ZoneFreezeContext] = [] self._freezes: list[_ZoneFreezeContext] = []
def __repr__(self) -> str: def __repr__(self) -> str:
"""Representation of a zone.""" """Representation of a zone."""
@ -418,9 +418,9 @@ class TimeoutManager:
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize TimeoutManager.""" """Initialize TimeoutManager."""
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
self._zones: Dict[str, _ZoneTimeoutManager] = {} self._zones: dict[str, _ZoneTimeoutManager] = {}
self._globals: List[_GlobalTaskContext] = [] self._globals: list[_GlobalTaskContext] = []
self._freezes: List[_GlobalFreezeContext] = [] self._freezes: list[_GlobalFreezeContext] = []
@property @property
def zones_done(self) -> bool: def zones_done(self) -> bool:
@ -433,17 +433,17 @@ class TimeoutManager:
return not self._freezes return not self._freezes
@property @property
def zones(self) -> Dict[str, _ZoneTimeoutManager]: def zones(self) -> dict[str, _ZoneTimeoutManager]:
"""Return all Zones.""" """Return all Zones."""
return self._zones return self._zones
@property @property
def global_tasks(self) -> List[_GlobalTaskContext]: def global_tasks(self) -> list[_GlobalTaskContext]:
"""Return all global Tasks.""" """Return all global Tasks."""
return self._globals return self._globals
@property @property
def global_freezes(self) -> List[_GlobalFreezeContext]: def global_freezes(self) -> list[_GlobalFreezeContext]:
"""Return all global Freezes.""" """Return all global Freezes."""
return self._freezes return self._freezes
@ -459,12 +459,12 @@ class TimeoutManager:
def async_timeout( def async_timeout(
self, timeout: float, zone_name: str = ZONE_GLOBAL, cool_down: float = 0 self, timeout: float, zone_name: str = ZONE_GLOBAL, cool_down: float = 0
) -> Union[_ZoneTaskContext, _GlobalTaskContext]: ) -> _ZoneTaskContext | _GlobalTaskContext:
"""Timeout based on a zone. """Timeout based on a zone.
For using as Async Context Manager. For using as Async Context Manager.
""" """
current_task: Optional[asyncio.Task[Any]] = asyncio.current_task() current_task: asyncio.Task[Any] | None = asyncio.current_task()
assert current_task assert current_task
# Global Zone # Global Zone
@ -483,7 +483,7 @@ class TimeoutManager:
def async_freeze( def async_freeze(
self, zone_name: str = ZONE_GLOBAL self, zone_name: str = ZONE_GLOBAL
) -> Union[_ZoneFreezeContext, _GlobalFreezeContext]: ) -> _ZoneFreezeContext | _GlobalFreezeContext:
"""Freeze all timer until job is done. """Freeze all timer until job is done.
For using as Async Context Manager. For using as Async Context Manager.
@ -502,7 +502,7 @@ class TimeoutManager:
def freeze( def freeze(
self, zone_name: str = ZONE_GLOBAL self, zone_name: str = ZONE_GLOBAL
) -> Union[_ZoneFreezeContext, _GlobalFreezeContext]: ) -> _ZoneFreezeContext | _GlobalFreezeContext:
"""Freeze all timer until job is done. """Freeze all timer until job is done.
For using as Context Manager. For using as Context Manager.

View File

@ -1,6 +1,7 @@
"""Unit system helper class and methods.""" """Unit system helper class and methods."""
from __future__ import annotations
from numbers import Number from numbers import Number
from typing import Dict, Optional
from homeassistant.const import ( from homeassistant.const import (
CONF_UNIT_SYSTEM_IMPERIAL, CONF_UNIT_SYSTEM_IMPERIAL,
@ -109,7 +110,7 @@ class UnitSystem:
return temperature_util.convert(temperature, from_unit, self.temperature_unit) return temperature_util.convert(temperature, from_unit, self.temperature_unit)
def length(self, length: Optional[float], from_unit: str) -> float: def length(self, length: float | None, from_unit: str) -> float:
"""Convert the given length to this unit system.""" """Convert the given length to this unit system."""
if not isinstance(length, Number): if not isinstance(length, Number):
raise TypeError(f"{length!s} is not a numeric value.") raise TypeError(f"{length!s} is not a numeric value.")
@ -119,7 +120,7 @@ class UnitSystem:
length, from_unit, self.length_unit length, from_unit, self.length_unit
) )
def pressure(self, pressure: Optional[float], from_unit: str) -> float: def pressure(self, pressure: float | None, from_unit: str) -> float:
"""Convert the given pressure to this unit system.""" """Convert the given pressure to this unit system."""
if not isinstance(pressure, Number): if not isinstance(pressure, Number):
raise TypeError(f"{pressure!s} is not a numeric value.") raise TypeError(f"{pressure!s} is not a numeric value.")
@ -129,7 +130,7 @@ class UnitSystem:
pressure, from_unit, self.pressure_unit pressure, from_unit, self.pressure_unit
) )
def volume(self, volume: Optional[float], from_unit: str) -> float: def volume(self, volume: float | None, from_unit: str) -> float:
"""Convert the given volume to this unit system.""" """Convert the given volume to this unit system."""
if not isinstance(volume, Number): if not isinstance(volume, Number):
raise TypeError(f"{volume!s} is not a numeric value.") raise TypeError(f"{volume!s} is not a numeric value.")
@ -137,7 +138,7 @@ class UnitSystem:
# type ignore: https://github.com/python/mypy/issues/7207 # type ignore: https://github.com/python/mypy/issues/7207
return volume_util.convert(volume, from_unit, self.volume_unit) # type: ignore return volume_util.convert(volume, from_unit, self.volume_unit) # type: ignore
def as_dict(self) -> Dict[str, str]: def as_dict(self) -> dict[str, str]:
"""Convert the unit system to a dictionary.""" """Convert the unit system to a dictionary."""
return { return {
LENGTH: self.length_unit, LENGTH: self.length_unit,

View File

@ -1,6 +1,7 @@
"""Deal with YAML input.""" """Deal with YAML input."""
from __future__ import annotations
from typing import Any, Dict, Set from typing import Any
from .objects import Input from .objects import Input
@ -14,14 +15,14 @@ class UndefinedSubstitution(Exception):
self.input = input self.input = input
def extract_inputs(obj: Any) -> Set[str]: def extract_inputs(obj: Any) -> set[str]:
"""Extract input from a structure.""" """Extract input from a structure."""
found: Set[str] = set() found: set[str] = set()
_extract_inputs(obj, found) _extract_inputs(obj, found)
return found return found
def _extract_inputs(obj: Any, found: Set[str]) -> None: def _extract_inputs(obj: Any, found: set[str]) -> None:
"""Extract input from a structure.""" """Extract input from a structure."""
if isinstance(obj, Input): if isinstance(obj, Input):
found.add(obj.name) found.add(obj.name)
@ -38,7 +39,7 @@ def _extract_inputs(obj: Any, found: Set[str]) -> None:
return return
def substitute(obj: Any, substitutions: Dict[str, Any]) -> Any: def substitute(obj: Any, substitutions: dict[str, Any]) -> Any:
"""Substitute values.""" """Substitute values."""
if isinstance(obj, Input): if isinstance(obj, Input):
if obj.name not in substitutions: if obj.name not in substitutions:

View File

@ -1,10 +1,12 @@
"""Custom loader.""" """Custom loader."""
from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
import fnmatch import fnmatch
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, TextIO, TypeVar, Union, overload from typing import Any, Dict, Iterator, List, TextIO, TypeVar, Union, overload
import yaml import yaml
@ -27,7 +29,7 @@ class Secrets:
def __init__(self, config_dir: Path): def __init__(self, config_dir: Path):
"""Initialize secrets.""" """Initialize secrets."""
self.config_dir = config_dir self.config_dir = config_dir
self._cache: Dict[Path, Dict[str, str]] = {} self._cache: dict[Path, dict[str, str]] = {}
def get(self, requester_path: str, secret: str) -> str: def get(self, requester_path: str, secret: str) -> str:
"""Return the value of a secret.""" """Return the value of a secret."""
@ -55,7 +57,7 @@ class Secrets:
raise HomeAssistantError(f"Secret {secret} not defined") raise HomeAssistantError(f"Secret {secret} not defined")
def _load_secret_yaml(self, secret_dir: Path) -> Dict[str, str]: def _load_secret_yaml(self, secret_dir: Path) -> dict[str, str]:
"""Load the secrets yaml from path.""" """Load the secrets yaml from path."""
secret_path = secret_dir / SECRET_YAML secret_path = secret_dir / SECRET_YAML
@ -90,7 +92,7 @@ class Secrets:
class SafeLineLoader(yaml.SafeLoader): class SafeLineLoader(yaml.SafeLoader):
"""Loader class that keeps track of line numbers.""" """Loader class that keeps track of line numbers."""
def __init__(self, stream: Any, secrets: Optional[Secrets] = None) -> None: def __init__(self, stream: Any, secrets: Secrets | None = None) -> None:
"""Initialize a safe line loader.""" """Initialize a safe line loader."""
super().__init__(stream) super().__init__(stream)
self.secrets = secrets self.secrets = secrets
@ -103,7 +105,7 @@ class SafeLineLoader(yaml.SafeLoader):
return node return node
def load_yaml(fname: str, secrets: Optional[Secrets] = None) -> JSON_TYPE: def load_yaml(fname: str, secrets: Secrets | None = None) -> JSON_TYPE:
"""Load a YAML file.""" """Load a YAML file."""
try: try:
with open(fname, encoding="utf-8") as conf_file: with open(fname, encoding="utf-8") as conf_file:
@ -113,9 +115,7 @@ def load_yaml(fname: str, secrets: Optional[Secrets] = None) -> JSON_TYPE:
raise HomeAssistantError(exc) from exc raise HomeAssistantError(exc) from exc
def parse_yaml( def parse_yaml(content: str | TextIO, secrets: Secrets | None = None) -> JSON_TYPE:
content: Union[str, TextIO], secrets: Optional[Secrets] = None
) -> JSON_TYPE:
"""Load a YAML file.""" """Load a YAML file."""
try: try:
# If configuration file is empty YAML returns None # If configuration file is empty YAML returns None
@ -131,14 +131,14 @@ def parse_yaml(
@overload @overload
def _add_reference( def _add_reference(
obj: Union[list, NodeListClass], loader: SafeLineLoader, node: yaml.nodes.Node obj: list | NodeListClass, loader: SafeLineLoader, node: yaml.nodes.Node
) -> NodeListClass: ) -> NodeListClass:
... ...
@overload @overload
def _add_reference( def _add_reference(
obj: Union[str, NodeStrClass], loader: SafeLineLoader, node: yaml.nodes.Node obj: str | NodeStrClass, loader: SafeLineLoader, node: yaml.nodes.Node
) -> NodeStrClass: ) -> NodeStrClass:
... ...
@ -223,7 +223,7 @@ def _include_dir_merge_named_yaml(
def _include_dir_list_yaml( def _include_dir_list_yaml(
loader: SafeLineLoader, node: yaml.nodes.Node loader: SafeLineLoader, node: yaml.nodes.Node
) -> List[JSON_TYPE]: ) -> list[JSON_TYPE]:
"""Load multiple files from directory as a list.""" """Load multiple files from directory as a list."""
loc = os.path.join(os.path.dirname(loader.name), node.value) loc = os.path.join(os.path.dirname(loader.name), node.value)
return [ return [
@ -238,7 +238,7 @@ def _include_dir_merge_list_yaml(
) -> JSON_TYPE: ) -> JSON_TYPE:
"""Load multiple files from directory as a merged list.""" """Load multiple files from directory as a merged list."""
loc: str = os.path.join(os.path.dirname(loader.name), node.value) loc: str = os.path.join(os.path.dirname(loader.name), node.value)
merged_list: List[JSON_TYPE] = [] merged_list: list[JSON_TYPE] = []
for fname in _find_files(loc, "*.yaml"): for fname in _find_files(loc, "*.yaml"):
if os.path.basename(fname) == SECRET_YAML: if os.path.basename(fname) == SECRET_YAML:
continue continue
@ -253,7 +253,7 @@ def _ordered_dict(loader: SafeLineLoader, node: yaml.nodes.MappingNode) -> Order
loader.flatten_mapping(node) loader.flatten_mapping(node)
nodes = loader.construct_pairs(node) nodes = loader.construct_pairs(node)
seen: Dict = {} seen: dict = {}
for (key, _), (child_node, _) in zip(nodes, node.value): for (key, _), (child_node, _) in zip(nodes, node.value):
line = child_node.start_mark.line line = child_node.start_mark.line