Update typing 03 (#48015)

pull/48062/head
Marc Mueller 2021-03-17 21:46:07 +01:00 committed by GitHub
parent 6fb2e63e49
commit fabd73f08b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 417 additions and 379 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio
from collections import OrderedDict
from datetime import timedelta
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, Optional, Tuple, cast
import jwt
@ -36,8 +36,8 @@ class InvalidProvider(Exception):
async def auth_manager_from_config(
hass: HomeAssistant,
provider_configs: List[Dict[str, Any]],
module_configs: List[Dict[str, Any]],
provider_configs: list[dict[str, Any]],
module_configs: list[dict[str, Any]],
) -> AuthManager:
"""Initialize an auth manager from config.
@ -87,8 +87,8 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
self,
handler_key: Any,
*,
context: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> data_entry_flow.FlowHandler:
"""Create a login flow."""
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
@ -97,8 +97,8 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
return await auth_provider.async_login_flow(context)
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
) -> Dict[str, Any]:
self, flow: data_entry_flow.FlowHandler, result: dict[str, Any]
) -> dict[str, Any]:
"""Return a user as result of login flow."""
flow = cast(LoginFlow, flow)
@ -157,22 +157,22 @@ class AuthManager:
self.login_flow = AuthManagerFlowManager(hass, self)
@property
def auth_providers(self) -> List[AuthProvider]:
def auth_providers(self) -> list[AuthProvider]:
"""Return a list of available auth providers."""
return list(self._providers.values())
@property
def auth_mfa_modules(self) -> List[MultiFactorAuthModule]:
def auth_mfa_modules(self) -> list[MultiFactorAuthModule]:
"""Return a list of available auth modules."""
return list(self._mfa_modules.values())
def get_auth_provider(
self, provider_type: str, provider_id: Optional[str]
) -> Optional[AuthProvider]:
self, provider_type: str, provider_id: str | None
) -> AuthProvider | None:
"""Return an auth provider, None if not found."""
return self._providers.get((provider_type, provider_id))
def get_auth_providers(self, provider_type: str) -> List[AuthProvider]:
def get_auth_providers(self, provider_type: str) -> list[AuthProvider]:
"""Return a List of auth provider of one type, Empty if not found."""
return [
provider
@ -180,30 +180,30 @@ class AuthManager:
if p_type == provider_type
]
def get_auth_mfa_module(self, module_id: str) -> Optional[MultiFactorAuthModule]:
def get_auth_mfa_module(self, module_id: str) -> MultiFactorAuthModule | None:
"""Return a multi-factor auth module, None if not found."""
return self._mfa_modules.get(module_id)
async def async_get_users(self) -> List[models.User]:
async def async_get_users(self) -> list[models.User]:
"""Retrieve all users."""
return await self._store.async_get_users()
async def async_get_user(self, user_id: str) -> Optional[models.User]:
async def async_get_user(self, user_id: str) -> models.User | None:
"""Retrieve a user."""
return await self._store.async_get_user(user_id)
async def async_get_owner(self) -> Optional[models.User]:
async def async_get_owner(self) -> models.User | None:
"""Retrieve the owner."""
users = await self.async_get_users()
return next((user for user in users if user.is_owner), None)
async def async_get_group(self, group_id: str) -> Optional[models.Group]:
async def async_get_group(self, group_id: str) -> models.Group | None:
"""Retrieve all groups."""
return await self._store.async_get_group(group_id)
async def async_get_user_by_credentials(
self, credentials: models.Credentials
) -> Optional[models.User]:
) -> models.User | None:
"""Get a user by credential, return None if not found."""
for user in await self.async_get_users():
for creds in user.credentials:
@ -213,7 +213,7 @@ class AuthManager:
return None
async def async_create_system_user(
self, name: str, group_ids: Optional[List[str]] = None
self, name: str, group_ids: list[str] | None = None
) -> models.User:
"""Create a system user."""
user = await self._store.async_create_user(
@ -225,10 +225,10 @@ class AuthManager:
return user
async def async_create_user(
self, name: str, group_ids: Optional[List[str]] = None
self, name: str, group_ids: list[str] | None = None
) -> models.User:
"""Create a user."""
kwargs: Dict[str, Any] = {
kwargs: dict[str, Any] = {
"name": name,
"is_active": True,
"group_ids": group_ids or [],
@ -294,12 +294,12 @@ class AuthManager:
async def async_update_user(
self,
user: models.User,
name: Optional[str] = None,
is_active: Optional[bool] = None,
group_ids: Optional[List[str]] = None,
name: str | None = None,
is_active: bool | None = None,
group_ids: list[str] | None = None,
) -> None:
"""Update a user."""
kwargs: Dict[str, Any] = {}
kwargs: dict[str, Any] = {}
if name is not None:
kwargs["name"] = name
if group_ids is not None:
@ -362,9 +362,9 @@ class AuthManager:
await module.async_depose_user(user.id)
async def async_get_enabled_mfa(self, user: models.User) -> Dict[str, str]:
async def async_get_enabled_mfa(self, user: models.User) -> dict[str, str]:
"""List enabled mfa modules for user."""
modules: Dict[str, str] = OrderedDict()
modules: dict[str, str] = OrderedDict()
for module_id, module in self._mfa_modules.items():
if await module.async_is_user_setup(user.id):
modules[module_id] = module.name
@ -373,12 +373,12 @@ class AuthManager:
async def async_create_refresh_token(
self,
user: models.User,
client_id: Optional[str] = None,
client_name: Optional[str] = None,
client_icon: Optional[str] = None,
token_type: Optional[str] = None,
client_id: str | None = None,
client_name: str | None = None,
client_icon: str | None = None,
token_type: str | None = None,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
credential: Optional[models.Credentials] = None,
credential: models.Credentials | None = None,
) -> models.RefreshToken:
"""Create a new refresh token for a user."""
if not user.is_active:
@ -432,13 +432,13 @@ class AuthManager:
async def async_get_refresh_token(
self, token_id: str
) -> Optional[models.RefreshToken]:
) -> models.RefreshToken | None:
"""Get refresh token by id."""
return await self._store.async_get_refresh_token(token_id)
async def async_get_refresh_token_by_token(
self, token: str
) -> Optional[models.RefreshToken]:
) -> models.RefreshToken | None:
"""Get refresh token by token."""
return await self._store.async_get_refresh_token_by_token(token)
@ -450,7 +450,7 @@ class AuthManager:
@callback
def async_create_access_token(
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
) -> str:
"""Create a new access token."""
self.async_validate_refresh_token(refresh_token, remote_ip)
@ -471,7 +471,7 @@ class AuthManager:
@callback
def _async_resolve_provider(
self, refresh_token: models.RefreshToken
) -> Optional[AuthProvider]:
) -> AuthProvider | None:
"""Get the auth provider for the given refresh token.
Raises an exception if the expected provider is no longer available or return
@ -492,7 +492,7 @@ class AuthManager:
@callback
def async_validate_refresh_token(
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
) -> None:
"""Validate that a refresh token is usable.
@ -504,7 +504,7 @@ class AuthManager:
async def async_validate_access_token(
self, token: str
) -> Optional[models.RefreshToken]:
) -> models.RefreshToken | None:
"""Return refresh token if an access token is valid."""
try:
unverif_claims = jwt.decode(token, verify=False)
@ -535,7 +535,7 @@ class AuthManager:
@callback
def _async_get_auth_provider(
self, credentials: models.Credentials
) -> Optional[AuthProvider]:
) -> AuthProvider | None:
"""Get auth provider from a set of credentials."""
auth_provider_key = (
credentials.auth_provider_type,

View File

@ -1,10 +1,12 @@
"""Storage for auth models."""
from __future__ import annotations
import asyncio
from collections import OrderedDict
from datetime import timedelta
import hmac
from logging import getLogger
from typing import Any, Dict, List, Optional
from typing import Any
from homeassistant.auth.const import ACCESS_TOKEN_EXPIRATION
from homeassistant.core import HomeAssistant, callback
@ -34,15 +36,15 @@ class AuthStore:
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the auth store."""
self.hass = hass
self._users: Optional[Dict[str, models.User]] = None
self._groups: Optional[Dict[str, models.Group]] = None
self._perm_lookup: Optional[PermissionLookup] = None
self._users: dict[str, models.User] | None = None
self._groups: dict[str, models.Group] | None = None
self._perm_lookup: PermissionLookup | None = None
self._store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True
)
self._lock = asyncio.Lock()
async def async_get_groups(self) -> List[models.Group]:
async def async_get_groups(self) -> list[models.Group]:
"""Retrieve all users."""
if self._groups is None:
await self._async_load()
@ -50,7 +52,7 @@ class AuthStore:
return list(self._groups.values())
async def async_get_group(self, group_id: str) -> Optional[models.Group]:
async def async_get_group(self, group_id: str) -> models.Group | None:
"""Retrieve all users."""
if self._groups is None:
await self._async_load()
@ -58,7 +60,7 @@ class AuthStore:
return self._groups.get(group_id)
async def async_get_users(self) -> List[models.User]:
async def async_get_users(self) -> list[models.User]:
"""Retrieve all users."""
if self._users is None:
await self._async_load()
@ -66,7 +68,7 @@ class AuthStore:
return list(self._users.values())
async def async_get_user(self, user_id: str) -> Optional[models.User]:
async def async_get_user(self, user_id: str) -> models.User | None:
"""Retrieve a user by id."""
if self._users is None:
await self._async_load()
@ -76,12 +78,12 @@ class AuthStore:
async def async_create_user(
self,
name: Optional[str],
is_owner: Optional[bool] = None,
is_active: Optional[bool] = None,
system_generated: Optional[bool] = None,
credentials: Optional[models.Credentials] = None,
group_ids: Optional[List[str]] = None,
name: str | None,
is_owner: bool | None = None,
is_active: bool | None = None,
system_generated: bool | None = None,
credentials: models.Credentials | None = None,
group_ids: list[str] | None = None,
) -> models.User:
"""Create a new user."""
if self._users is None:
@ -97,7 +99,7 @@ class AuthStore:
raise ValueError(f"Invalid group specified {group_id}")
groups.append(group)
kwargs: Dict[str, Any] = {
kwargs: dict[str, Any] = {
"name": name,
# Until we get group management, we just put everyone in the
# same group.
@ -146,9 +148,9 @@ class AuthStore:
async def async_update_user(
self,
user: models.User,
name: Optional[str] = None,
is_active: Optional[bool] = None,
group_ids: Optional[List[str]] = None,
name: str | None = None,
is_active: bool | None = None,
group_ids: list[str] | None = None,
) -> None:
"""Update a user."""
assert self._groups is not None
@ -203,15 +205,15 @@ class AuthStore:
async def async_create_refresh_token(
self,
user: models.User,
client_id: Optional[str] = None,
client_name: Optional[str] = None,
client_icon: Optional[str] = None,
client_id: str | None = None,
client_name: str | None = None,
client_icon: str | None = None,
token_type: str = models.TOKEN_TYPE_NORMAL,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
credential: Optional[models.Credentials] = None,
credential: models.Credentials | None = None,
) -> models.RefreshToken:
"""Create a new token for a user."""
kwargs: Dict[str, Any] = {
kwargs: dict[str, Any] = {
"user": user,
"client_id": client_id,
"token_type": token_type,
@ -244,7 +246,7 @@ class AuthStore:
async def async_get_refresh_token(
self, token_id: str
) -> Optional[models.RefreshToken]:
) -> models.RefreshToken | None:
"""Get refresh token by id."""
if self._users is None:
await self._async_load()
@ -259,7 +261,7 @@ class AuthStore:
async def async_get_refresh_token_by_token(
self, token: str
) -> Optional[models.RefreshToken]:
) -> models.RefreshToken | None:
"""Get refresh token by token."""
if self._users is None:
await self._async_load()
@ -276,7 +278,7 @@ class AuthStore:
@callback
def async_log_refresh_token_usage(
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
) -> None:
"""Update refresh token last used information."""
refresh_token.last_used_at = dt_util.utcnow()
@ -309,9 +311,9 @@ class AuthStore:
self._set_defaults()
return
users: Dict[str, models.User] = OrderedDict()
groups: Dict[str, models.Group] = OrderedDict()
credentials: Dict[str, models.Credentials] = OrderedDict()
users: dict[str, models.User] = OrderedDict()
groups: dict[str, models.Group] = OrderedDict()
credentials: dict[str, models.Credentials] = OrderedDict()
# Soft-migrating data as we load. We are going to make sure we have a
# read only group and an admin group. There are two states that we can
@ -328,7 +330,7 @@ class AuthStore:
# was added.
for group_dict in data.get("groups", []):
policy: Optional[PolicyType] = None
policy: PolicyType | None = None
if group_dict["id"] == GROUP_ID_ADMIN:
has_admin_group = True
@ -489,7 +491,7 @@ class AuthStore:
self._store.async_delay_save(self._data_to_save, 1)
@callback
def _data_to_save(self) -> Dict:
def _data_to_save(self) -> dict:
"""Return the data to store."""
assert self._users is not None
assert self._groups is not None
@ -508,7 +510,7 @@ class AuthStore:
groups = []
for group in self._groups.values():
g_dict: Dict[str, Any] = {
g_dict: dict[str, Any] = {
"id": group.id,
# Name not read for sys groups. Kept here for backwards compat
"name": group.name,
@ -567,7 +569,7 @@ class AuthStore:
"""Set default values for auth store."""
self._users = OrderedDict()
groups: Dict[str, models.Group] = OrderedDict()
groups: dict[str, models.Group] = OrderedDict()
admin_group = _system_admin_group()
groups[admin_group.id] = admin_group
user_group = _system_user_group()

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import importlib
import logging
import types
from typing import Any, Dict, Optional
from typing import Any
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -38,7 +38,7 @@ class MultiFactorAuthModule:
DEFAULT_TITLE = "Unnamed auth module"
MAX_RETRY_TIME = 3
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Initialize an auth module."""
self.hass = hass
self.config = config
@ -87,7 +87,7 @@ class MultiFactorAuthModule:
"""Return whether user is setup."""
raise NotImplementedError
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
"""Return True if validation passed."""
raise NotImplementedError
@ -104,14 +104,14 @@ class SetupFlow(data_entry_flow.FlowHandler):
self._user_id = user_id
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Handle the first step of setup flow.
Return self.async_show_form(step_id='init') if user_input is None.
Return self.async_create_entry(data={'result': result}) if finish.
"""
errors: Dict[str, str] = {}
errors: dict[str, str] = {}
if user_input:
result = await self._auth_module.async_setup_user(self._user_id, user_input)
@ -125,7 +125,7 @@ class SetupFlow(data_entry_flow.FlowHandler):
async def auth_mfa_module_from_config(
hass: HomeAssistant, config: Dict[str, Any]
hass: HomeAssistant, config: dict[str, Any]
) -> MultiFactorAuthModule:
"""Initialize an auth module from a config."""
module_name = config[CONF_TYPE]

View File

@ -1,5 +1,7 @@
"""Example auth module."""
from typing import Any, Dict
from __future__ import annotations
from typing import Any
import voluptuous as vol
@ -28,7 +30,7 @@ class InsecureExampleModule(MultiFactorAuthModule):
DEFAULT_TITLE = "Insecure Personal Identify Number"
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Initialize the user data store."""
super().__init__(hass, config)
self._data = config["data"]
@ -80,7 +82,7 @@ class InsecureExampleModule(MultiFactorAuthModule):
return True
return False
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
"""Return True if validation passed."""
for data in self._data:
if data["user_id"] == user_id:

View File

@ -2,10 +2,12 @@
Sending HOTP through notify service
"""
from __future__ import annotations
import asyncio
from collections import OrderedDict
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict
import attr
import voluptuous as vol
@ -79,8 +81,8 @@ class NotifySetting:
secret: str = attr.ib(factory=_generate_secret) # not persistent
counter: int = attr.ib(factory=_generate_random) # not persistent
notify_service: Optional[str] = attr.ib(default=None)
target: Optional[str] = attr.ib(default=None)
notify_service: str | None = attr.ib(default=None)
target: str | None = attr.ib(default=None)
_UsersDict = Dict[str, NotifySetting]
@ -92,10 +94,10 @@ class NotifyAuthModule(MultiFactorAuthModule):
DEFAULT_TITLE = "Notify One-Time Password"
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Initialize the user data store."""
super().__init__(hass, config)
self._user_settings: Optional[_UsersDict] = None
self._user_settings: _UsersDict | None = None
self._user_store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True
)
@ -146,7 +148,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
)
@callback
def aync_get_available_notify_services(self) -> List[str]:
def aync_get_available_notify_services(self) -> list[str]:
"""Return list of notify services."""
unordered_services = set()
@ -198,7 +200,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
return user_id in self._user_settings
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
"""Return True if validation passed."""
if self._user_settings is None:
await self._async_load()
@ -258,7 +260,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
)
async def async_notify(
self, code: str, notify_service: str, target: Optional[str] = None
self, code: str, notify_service: str, target: str | None = None
) -> None:
"""Send code by notify service."""
data = {"message": self._message_template.format(code)}
@ -276,23 +278,23 @@ class NotifySetupFlow(SetupFlow):
auth_module: NotifyAuthModule,
setup_schema: vol.Schema,
user_id: str,
available_notify_services: List[str],
available_notify_services: list[str],
) -> None:
"""Initialize the setup flow."""
super().__init__(auth_module, setup_schema, user_id)
# to fix typing complaint
self._auth_module: NotifyAuthModule = auth_module
self._available_notify_services = available_notify_services
self._secret: Optional[str] = None
self._count: Optional[int] = None
self._notify_service: Optional[str] = None
self._target: Optional[str] = None
self._secret: str | None = None
self._count: int | None = None
self._notify_service: str | None = None
self._target: str | None = None
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Let user select available notify services."""
errors: Dict[str, str] = {}
errors: dict[str, str] = {}
hass = self._auth_module.hass
if user_input:
@ -306,7 +308,7 @@ class NotifySetupFlow(SetupFlow):
if not self._available_notify_services:
return self.async_abort(reason="no_available_service")
schema: Dict[str, Any] = OrderedDict()
schema: dict[str, Any] = OrderedDict()
schema["notify_service"] = vol.In(self._available_notify_services)
schema["target"] = vol.Optional(str)
@ -315,10 +317,10 @@ class NotifySetupFlow(SetupFlow):
)
async def async_step_setup(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Verify user can receive one-time password."""
errors: Dict[str, str] = {}
errors: dict[str, str] = {}
hass = self._auth_module.hass
if user_input:

View File

@ -1,7 +1,9 @@
"""Time-based One Time Password auth module."""
from __future__ import annotations
import asyncio
from io import BytesIO
from typing import Any, Dict, Optional, Tuple
from typing import Any
import voluptuous as vol
@ -50,7 +52,7 @@ def _generate_qr_code(data: str) -> str:
)
def _generate_secret_and_qr_code(username: str) -> Tuple[str, str, str]:
def _generate_secret_and_qr_code(username: str) -> tuple[str, str, str]:
"""Generate a secret, url, and QR code."""
import pyotp # pylint: disable=import-outside-toplevel
@ -69,10 +71,10 @@ class TotpAuthModule(MultiFactorAuthModule):
DEFAULT_TITLE = "Time-based One Time Password"
MAX_RETRY_TIME = 5
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Initialize the user data store."""
super().__init__(hass, config)
self._users: Optional[Dict[str, str]] = None
self._users: dict[str, str] | None = None
self._user_store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True
)
@ -100,7 +102,7 @@ class TotpAuthModule(MultiFactorAuthModule):
"""Save data."""
await self._user_store.async_save({STORAGE_USERS: self._users})
def _add_ota_secret(self, user_id: str, secret: Optional[str] = None) -> str:
def _add_ota_secret(self, user_id: str, secret: str | None = None) -> str:
"""Create a ota_secret for user."""
import pyotp # pylint: disable=import-outside-toplevel
@ -145,7 +147,7 @@ class TotpAuthModule(MultiFactorAuthModule):
return user_id in self._users # type: ignore
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
"""Return True if validation passed."""
if self._users is None:
await self._async_load()
@ -181,13 +183,13 @@ class TotpSetupFlow(SetupFlow):
# to fix typing complaint
self._auth_module: TotpAuthModule = auth_module
self._user = user
self._ota_secret: Optional[str] = None
self._ota_secret: str | None = None
self._url = None # type Optional[str]
self._image = None # type Optional[str]
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Handle the first step of setup flow.
Return self.async_show_form(step_id='init') if user_input is None.
@ -195,7 +197,7 @@ class TotpSetupFlow(SetupFlow):
"""
import pyotp # pylint: disable=import-outside-toplevel
errors: Dict[str, str] = {}
errors: dict[str, str] = {}
if user_input:
verified = await self.hass.async_add_executor_job(

View File

@ -1,7 +1,9 @@
"""Auth models."""
from __future__ import annotations
from datetime import datetime, timedelta
import secrets
from typing import Dict, List, NamedTuple, Optional
from typing import NamedTuple
import uuid
import attr
@ -21,7 +23,7 @@ TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
class Group:
"""A group."""
name: Optional[str] = attr.ib()
name: str | None = attr.ib()
policy: perm_mdl.PolicyType = attr.ib()
id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
system_generated: bool = attr.ib(default=False)
@ -31,24 +33,24 @@ class Group:
class User:
"""A user."""
name: Optional[str] = attr.ib()
name: str | None = attr.ib()
perm_lookup: perm_mdl.PermissionLookup = attr.ib(eq=False, order=False)
id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
is_owner: bool = attr.ib(default=False)
is_active: bool = attr.ib(default=False)
system_generated: bool = attr.ib(default=False)
groups: List[Group] = attr.ib(factory=list, eq=False, order=False)
groups: list[Group] = attr.ib(factory=list, eq=False, order=False)
# List of credentials of a user.
credentials: List["Credentials"] = attr.ib(factory=list, eq=False, order=False)
credentials: list["Credentials"] = attr.ib(factory=list, eq=False, order=False)
# Tokens associated with a user.
refresh_tokens: Dict[str, "RefreshToken"] = attr.ib(
refresh_tokens: dict[str, "RefreshToken"] = attr.ib(
factory=dict, eq=False, order=False
)
_permissions: Optional[perm_mdl.PolicyPermissions] = attr.ib(
_permissions: perm_mdl.PolicyPermissions | None = attr.ib(
init=False,
eq=False,
order=False,
@ -89,10 +91,10 @@ class RefreshToken:
"""RefreshToken for a user to grant new access tokens."""
user: User = attr.ib()
client_id: Optional[str] = attr.ib()
client_id: str | None = attr.ib()
access_token_expiration: timedelta = attr.ib()
client_name: Optional[str] = attr.ib(default=None)
client_icon: Optional[str] = attr.ib(default=None)
client_name: str | None = attr.ib(default=None)
client_icon: str | None = attr.ib(default=None)
token_type: str = attr.ib(
default=TOKEN_TYPE_NORMAL,
validator=attr.validators.in_(
@ -104,12 +106,12 @@ class RefreshToken:
token: str = attr.ib(factory=lambda: secrets.token_hex(64))
jwt_key: str = attr.ib(factory=lambda: secrets.token_hex(64))
last_used_at: Optional[datetime] = attr.ib(default=None)
last_used_ip: Optional[str] = attr.ib(default=None)
last_used_at: datetime | None = attr.ib(default=None)
last_used_ip: str | None = attr.ib(default=None)
credential: Optional["Credentials"] = attr.ib(default=None)
credential: "Credentials" | None = attr.ib(default=None)
version: Optional[str] = attr.ib(default=__version__)
version: str | None = attr.ib(default=__version__)
@attr.s(slots=True)
@ -117,7 +119,7 @@ class Credentials:
"""Credentials for a user on an auth provider."""
auth_provider_type: str = attr.ib()
auth_provider_id: Optional[str] = attr.ib()
auth_provider_id: str | None = attr.ib()
# Allow the auth provider to store data to represent their auth.
data: dict = attr.ib()
@ -129,5 +131,5 @@ class Credentials:
class UserMeta(NamedTuple):
"""User metadata."""
name: Optional[str]
name: str | None
is_active: bool

View File

@ -1,6 +1,8 @@
"""Permissions for Home Assistant."""
from __future__ import annotations
import logging
from typing import Any, Callable, Optional
from typing import Any, Callable
import voluptuous as vol
@ -19,7 +21,7 @@ _LOGGER = logging.getLogger(__name__)
class AbstractPermissions:
"""Default permissions class."""
_cached_entity_func: Optional[Callable[[str, str], bool]] = None
_cached_entity_func: Callable[[str, str], bool] | None = None
def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access."""

View File

@ -1,6 +1,8 @@
"""Entity permissions."""
from __future__ import annotations
from collections import OrderedDict
from typing import Callable, Optional
from typing import Callable
import voluptuous as vol
@ -43,14 +45,14 @@ ENTITY_POLICY_SCHEMA = vol.Any(
def _lookup_domain(
perm_lookup: PermissionLookup, domains_dict: SubCategoryDict, entity_id: str
) -> Optional[ValueType]:
) -> ValueType | None:
"""Look up entity permissions by domain."""
return domains_dict.get(entity_id.split(".", 1)[0])
def _lookup_area(
perm_lookup: PermissionLookup, area_dict: SubCategoryDict, entity_id: str
) -> Optional[ValueType]:
) -> ValueType | None:
"""Look up entity permissions by area."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
@ -67,7 +69,7 @@ def _lookup_area(
def _lookup_device(
perm_lookup: PermissionLookup, devices_dict: SubCategoryDict, entity_id: str
) -> Optional[ValueType]:
) -> ValueType | None:
"""Look up entity permissions by device."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
@ -79,7 +81,7 @@ def _lookup_device(
def _lookup_entity_id(
perm_lookup: PermissionLookup, entities_dict: SubCategoryDict, entity_id: str
) -> Optional[ValueType]:
) -> ValueType | None:
"""Look up entity permission by entity id."""
return entities_dict.get(entity_id)

View File

@ -1,13 +1,15 @@
"""Merging of policies."""
from typing import Dict, List, Set, cast
from __future__ import annotations
from typing import cast
from .types import CategoryType, PolicyType
def merge_policies(policies: List[PolicyType]) -> PolicyType:
def merge_policies(policies: list[PolicyType]) -> PolicyType:
"""Merge policies."""
new_policy: Dict[str, CategoryType] = {}
seen: Set[str] = set()
new_policy: dict[str, CategoryType] = {}
seen: set[str] = set()
for policy in policies:
for category in policy:
if category in seen:
@ -20,7 +22,7 @@ def merge_policies(policies: List[PolicyType]) -> PolicyType:
return new_policy
def _merge_policies(sources: List[CategoryType]) -> CategoryType:
def _merge_policies(sources: list[CategoryType]) -> CategoryType:
"""Merge a policy."""
# When merging policies, the most permissive wins.
# This means we order it like this:
@ -34,7 +36,7 @@ def _merge_policies(sources: List[CategoryType]) -> CategoryType:
# merge each key in the source.
policy: CategoryType = None
seen: Set[str] = set()
seen: set[str] = set()
for source in sources:
if source is None:
continue

View File

@ -1,6 +1,8 @@
"""Helpers to deal with permissions."""
from __future__ import annotations
from functools import wraps
from typing import Callable, Dict, List, Optional, cast
from typing import Callable, Dict, Optional, cast
from .const import SUBCAT_ALL
from .models import PermissionLookup
@ -45,7 +47,7 @@ def compile_policy(
assert isinstance(policy, dict)
funcs: List[Callable[[str, str], Optional[bool]]] = []
funcs: list[Callable[[str, str], bool | None]] = []
for key, lookup_func in subcategories.items():
lookup_value = policy.get(key)
@ -80,10 +82,10 @@ def compile_policy(
def _gen_dict_test_func(
perm_lookup: PermissionLookup, lookup_func: LookupFunc, lookup_dict: SubCategoryDict
) -> Callable[[str, str], Optional[bool]]:
) -> Callable[[str, str], bool | None]:
"""Generate a lookup function."""
def test_value(object_id: str, key: str) -> Optional[bool]:
def test_value(object_id: str, key: str) -> bool | None:
"""Test if permission is allowed based on the keys."""
schema: ValueType = lookup_func(perm_lookup, lookup_dict, object_id)

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import importlib
import logging
import types
from typing import Any, Dict, List, Optional
from typing import Any
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -42,7 +42,7 @@ class AuthProvider:
DEFAULT_TITLE = "Unnamed auth provider"
def __init__(
self, hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
self, hass: HomeAssistant, store: AuthStore, config: dict[str, Any]
) -> None:
"""Initialize an auth provider."""
self.hass = hass
@ -50,7 +50,7 @@ class AuthProvider:
self.config = config
@property
def id(self) -> Optional[str]:
def id(self) -> str | None:
"""Return id of the auth provider.
Optional, can be None.
@ -72,7 +72,7 @@ class AuthProvider:
"""Return whether multi-factor auth supported by the auth provider."""
return True
async def async_credentials(self) -> List[Credentials]:
async def async_credentials(self) -> list[Credentials]:
"""Return all credentials of this provider."""
users = await self.store.async_get_users()
return [
@ -86,7 +86,7 @@ class AuthProvider:
]
@callback
def async_create_credentials(self, data: Dict[str, str]) -> Credentials:
def async_create_credentials(self, data: dict[str, str]) -> Credentials:
"""Create credentials."""
return Credentials(
auth_provider_type=self.type, auth_provider_id=self.id, data=data
@ -94,7 +94,7 @@ class AuthProvider:
# Implement by extending class
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return the data flow for logging in with auth provider.
Auth provider should extend LoginFlow and return an instance.
@ -102,7 +102,7 @@ class AuthProvider:
raise NotImplementedError
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]
self, flow_result: dict[str, str]
) -> Credentials:
"""Get credentials based on the flow result."""
raise NotImplementedError
@ -121,7 +121,7 @@ class AuthProvider:
@callback
def async_validate_refresh_token(
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
self, refresh_token: RefreshToken, remote_ip: str | None = None
) -> None:
"""Verify a refresh token is still valid.
@ -131,7 +131,7 @@ class AuthProvider:
async def auth_provider_from_config(
hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
hass: HomeAssistant, store: AuthStore, config: dict[str, Any]
) -> AuthProvider:
"""Initialize an auth provider from a config."""
provider_name = config[CONF_TYPE]
@ -188,17 +188,17 @@ class LoginFlow(data_entry_flow.FlowHandler):
def __init__(self, auth_provider: AuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
self._auth_module_id: Optional[str] = None
self._auth_module_id: str | None = None
self._auth_manager = auth_provider.hass.auth
self.available_mfa_modules: Dict[str, str] = {}
self.available_mfa_modules: dict[str, str] = {}
self.created_at = dt_util.utcnow()
self.invalid_mfa_times = 0
self.user: Optional[User] = None
self.credential: Optional[Credentials] = None
self.user: User | None = None
self.credential: Credentials | None = None
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Handle the first step of login flow.
Return self.async_show_form(step_id='init') if user_input is None.
@ -207,8 +207,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
raise NotImplementedError
async def async_step_select_mfa_module(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Handle the step of select mfa module."""
errors = {}
@ -232,8 +232,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
)
async def async_step_mfa(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Handle the step of mfa validation."""
assert self.credential
assert self.user
@ -273,7 +273,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
if not errors:
return await self.async_finish(self.credential)
description_placeholders: Dict[str, Optional[str]] = {
description_placeholders: dict[str, str | None] = {
"mfa_module_name": auth_module.name,
"mfa_module_id": auth_module.id,
}
@ -285,6 +285,6 @@ class LoginFlow(data_entry_flow.FlowHandler):
errors=errors,
)
async def async_finish(self, flow_result: Any) -> Dict:
async def async_finish(self, flow_result: Any) -> dict:
"""Handle the pass of login flow."""
return self.async_create_entry(title=self._auth_provider.name, data=flow_result)

View File

@ -1,10 +1,11 @@
"""Auth provider that validates credentials via an external command."""
from __future__ import annotations
import asyncio.subprocess
import collections
import logging
import os
from typing import Any, Dict, Optional, cast
from typing import Any, cast
import voluptuous as vol
@ -51,9 +52,9 @@ class CommandLineAuthProvider(AuthProvider):
attributes provided by external programs.
"""
super().__init__(*args, **kwargs)
self._user_meta: Dict[str, Dict[str, Any]] = {}
self._user_meta: dict[str, dict[str, Any]] = {}
async def async_login_flow(self, context: Optional[dict]) -> LoginFlow:
async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return a flow to login."""
return CommandLineLoginFlow(self)
@ -82,7 +83,7 @@ class CommandLineAuthProvider(AuthProvider):
raise InvalidAuthError
if self.config[CONF_META]:
meta: Dict[str, str] = {}
meta: dict[str, str] = {}
for _line in stdout.splitlines():
try:
line = _line.decode().lstrip()
@ -99,7 +100,7 @@ class CommandLineAuthProvider(AuthProvider):
self._user_meta[username] = meta
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]
self, flow_result: dict[str, str]
) -> Credentials:
"""Get credentials based on the flow result."""
username = flow_result["username"]
@ -125,8 +126,8 @@ class CommandLineLoginFlow(LoginFlow):
"""Handler for the login flow."""
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Handle the step of the form."""
errors = {}
@ -143,7 +144,7 @@ class CommandLineLoginFlow(LoginFlow):
user_input.pop("password")
return await self.async_finish(user_input)
schema: Dict[str, type] = collections.OrderedDict()
schema: dict[str, type] = collections.OrderedDict()
schema["username"] = str
schema["password"] = str

View File

@ -5,7 +5,7 @@ import asyncio
import base64
from collections import OrderedDict
import logging
from typing import Any, Dict, List, Optional, Set, cast
from typing import Any, cast
import bcrypt
import voluptuous as vol
@ -21,7 +21,7 @@ STORAGE_VERSION = 1
STORAGE_KEY = "auth_provider.homeassistant"
def _disallow_id(conf: Dict[str, Any]) -> Dict[str, Any]:
def _disallow_id(conf: dict[str, Any]) -> dict[str, Any]:
"""Disallow ID in config."""
if CONF_ID in conf:
raise vol.Invalid("ID is not allowed for the homeassistant auth provider.")
@ -62,7 +62,7 @@ class Data:
self._store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True
)
self._data: Optional[Dict[str, Any]] = None
self._data: dict[str, Any] | None = None
# Legacy mode will allow usernames to start/end with whitespace
# and will compare usernames case-insensitive.
# Remove in 2020 or when we launch 1.0.
@ -83,7 +83,7 @@ class Data:
if data is None:
data = {"users": []}
seen: Set[str] = set()
seen: set[str] = set()
for user in data["users"]:
username = user["username"]
@ -121,7 +121,7 @@ class Data:
self._data = data
@property
def users(self) -> List[Dict[str, str]]:
def users(self) -> list[dict[str, str]]:
"""Return users."""
return self._data["users"] # type: ignore
@ -220,7 +220,7 @@ class HassAuthProvider(AuthProvider):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize an Home Assistant auth provider."""
super().__init__(*args, **kwargs)
self.data: Optional[Data] = None
self.data: Data | None = None
self._init_lock = asyncio.Lock()
async def async_initialize(self) -> None:
@ -233,7 +233,7 @@ class HassAuthProvider(AuthProvider):
await data.async_load()
self.data = data
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return a flow to login."""
return HassLoginFlow(self)
@ -277,7 +277,7 @@ class HassAuthProvider(AuthProvider):
await self.data.async_save()
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]
self, flow_result: dict[str, str]
) -> Credentials:
"""Get credentials based on the flow result."""
if self.data is None:
@ -318,8 +318,8 @@ class HassLoginFlow(LoginFlow):
"""Handler for the login flow."""
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Handle the step of the form."""
errors = {}
@ -335,7 +335,7 @@ class HassLoginFlow(LoginFlow):
user_input.pop("password")
return await self.async_finish(user_input)
schema: Dict[str, type] = OrderedDict()
schema: dict[str, type] = OrderedDict()
schema["username"] = str
schema["password"] = str

View File

@ -1,7 +1,9 @@
"""Example auth provider."""
from __future__ import annotations
from collections import OrderedDict
import hmac
from typing import Any, Dict, Optional, cast
from typing import Any, cast
import voluptuous as vol
@ -33,7 +35,7 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords."""
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return a flow to login."""
return ExampleLoginFlow(self)
@ -60,7 +62,7 @@ class ExampleAuthProvider(AuthProvider):
raise InvalidAuthError
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]
self, flow_result: dict[str, str]
) -> Credentials:
"""Get credentials based on the flow result."""
username = flow_result["username"]
@ -94,8 +96,8 @@ class ExampleLoginFlow(LoginFlow):
"""Handler for the login flow."""
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Handle the step of the form."""
errors = {}
@ -111,7 +113,7 @@ class ExampleLoginFlow(LoginFlow):
user_input.pop("password")
return await self.async_finish(user_input)
schema: Dict[str, type] = OrderedDict()
schema: dict[str, type] = OrderedDict()
schema["username"] = str
schema["password"] = str

View File

@ -3,8 +3,10 @@ Support Legacy API password auth provider.
It will be removed when auth system production ready
"""
from __future__ import annotations
import hmac
from typing import Any, Dict, Optional, cast
from typing import Any, cast
import voluptuous as vol
@ -40,7 +42,7 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
"""Return api_password."""
return str(self.config[CONF_API_PASSWORD])
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return a flow to login."""
return LegacyLoginFlow(self)
@ -55,7 +57,7 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
raise InvalidAuthError
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]
self, flow_result: dict[str, str]
) -> Credentials:
"""Return credentials for this login."""
credentials = await self.async_credentials()
@ -79,8 +81,8 @@ class LegacyLoginFlow(LoginFlow):
"""Handler for the login flow."""
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Handle the step of the form."""
errors = {}

View File

@ -3,6 +3,8 @@
It shows list of users if access from trusted network.
Abort login flow if not access from trusted network.
"""
from __future__ import annotations
from ipaddress import (
IPv4Address,
IPv4Network,
@ -11,7 +13,7 @@ from ipaddress import (
ip_address,
ip_network,
)
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Union, cast
import voluptuous as vol
@ -68,12 +70,12 @@ class TrustedNetworksAuthProvider(AuthProvider):
DEFAULT_TITLE = "Trusted Networks"
@property
def trusted_networks(self) -> List[IPNetwork]:
def trusted_networks(self) -> list[IPNetwork]:
"""Return trusted networks."""
return cast(List[IPNetwork], self.config[CONF_TRUSTED_NETWORKS])
@property
def trusted_users(self) -> Dict[IPNetwork, Any]:
def trusted_users(self) -> dict[IPNetwork, Any]:
"""Return trusted users per network."""
return cast(Dict[IPNetwork, Any], self.config[CONF_TRUSTED_USERS])
@ -82,7 +84,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
"""Trusted Networks auth provider does not support MFA."""
return False
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
async def async_login_flow(self, context: dict | None) -> LoginFlow:
"""Return a flow to login."""
assert context is not None
ip_addr = cast(IPAddress, context.get("ip_address"))
@ -125,7 +127,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
)
async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]
self, flow_result: dict[str, str]
) -> Credentials:
"""Get credentials based on the flow result."""
user_id = flow_result["user"]
@ -169,7 +171,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
@callback
def async_validate_refresh_token(
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
self, refresh_token: RefreshToken, remote_ip: str | None = None
) -> None:
"""Verify a refresh token is still valid."""
if remote_ip is None:
@ -186,7 +188,7 @@ class TrustedNetworksLoginFlow(LoginFlow):
self,
auth_provider: TrustedNetworksAuthProvider,
ip_addr: IPAddress,
available_users: Dict[str, Optional[str]],
available_users: dict[str, str | None],
allow_bypass_login: bool,
) -> None:
"""Initialize the login flow."""
@ -196,8 +198,8 @@ class TrustedNetworksLoginFlow(LoginFlow):
self._allow_bypass_login = allow_bypass_login
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
self, user_input: dict[str, str] | None = None
) -> dict[str, Any]:
"""Handle the step of the form."""
try:
cast(

View File

@ -1,11 +1,13 @@
"""Home Assistant command line scripts."""
from __future__ import annotations
import argparse
import asyncio
import importlib
import logging
import os
import sys
from typing import List, Optional, Sequence, Text
from typing import Sequence
from homeassistant import runner
from homeassistant.bootstrap import async_mount_local_lib_path
@ -16,7 +18,7 @@ from homeassistant.util.package import install_package, is_installed, is_virtual
# mypy: allow-untyped-defs, no-warn-return-any
def run(args: List) -> int:
def run(args: list) -> int:
"""Run a script."""
scripts = []
path = os.path.dirname(__file__)
@ -65,7 +67,7 @@ def run(args: List) -> int:
return script.run(args[1:]) # type: ignore
def extract_config_dir(args: Optional[Sequence[Text]] = None) -> str:
def extract_config_dir(args: Sequence[str] | None = None) -> str:
"""Extract the config dir from the arguments or get the default."""
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("-c", "--config", default=None)

View File

@ -1,4 +1,6 @@
"""Script to run benchmarks."""
from __future__ import annotations
import argparse
import asyncio
import collections
@ -7,7 +9,7 @@ from datetime import datetime
import json
import logging
from timeit import default_timer as timer
from typing import Callable, Dict, TypeVar
from typing import Callable, TypeVar
from homeassistant import core
from homeassistant.components.websocket_api.const import JSON_DUMP
@ -21,7 +23,7 @@ from homeassistant.util import dt as dt_util
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name
BENCHMARKS: Dict[str, Callable] = {}
BENCHMARKS: dict[str, Callable] = {}
def run(args):

View File

@ -1,4 +1,6 @@
"""Script to check the configuration file."""
from __future__ import annotations
import argparse
import asyncio
from collections import OrderedDict
@ -6,7 +8,7 @@ from collections.abc import Mapping, Sequence
from glob import glob
import logging
import os
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable
from unittest.mock import patch
from homeassistant import core
@ -22,13 +24,13 @@ REQUIREMENTS = ("colorlog==4.7.2",)
_LOGGER = logging.getLogger(__name__)
# pylint: disable=protected-access
MOCKS: Dict[str, Tuple[str, Callable]] = {
MOCKS: dict[str, tuple[str, Callable]] = {
"load": ("homeassistant.util.yaml.loader.load_yaml", yaml_loader.load_yaml),
"load*": ("homeassistant.config.load_yaml", yaml_loader.load_yaml),
"secrets": ("homeassistant.util.yaml.loader.secret_yaml", yaml_loader.secret_yaml),
}
PATCHES: Dict[str, Any] = {}
PATCHES: dict[str, Any] = {}
C_HEAD = "bold"
ERROR_STR = "General Errors"
@ -48,7 +50,7 @@ def color(the_color, *args, reset=None):
raise ValueError(f"Invalid color {k!s} in {the_color}") from k
def run(script_args: List) -> int:
def run(script_args: list) -> int:
"""Handle check config commandline script."""
parser = argparse.ArgumentParser(description="Check Home Assistant configuration.")
parser.add_argument("--script", choices=["check_config"])
@ -83,7 +85,7 @@ def run(script_args: List) -> int:
res = check(config_dir, args.secrets)
domain_info: List[str] = []
domain_info: list[str] = []
if args.info:
domain_info = args.info.split(",")
@ -123,7 +125,7 @@ def run(script_args: List) -> int:
dump_dict(res["components"].get(domain))
if args.secrets:
flatsecret: Dict[str, str] = {}
flatsecret: dict[str, str] = {}
for sfn, sdict in res["secret_cache"].items():
sss = []
@ -149,7 +151,7 @@ def run(script_args: List) -> int:
def check(config_dir, secrets=False):
"""Perform a check by mocking hass load functions."""
logging.getLogger("homeassistant.loader").setLevel(logging.CRITICAL)
res: Dict[str, Any] = {
res: dict[str, Any] = {
"yaml_files": OrderedDict(), # yaml_files loaded
"secrets": OrderedDict(), # secret cache and secrets loaded
"except": OrderedDict(), # exceptions raised (with config)

View File

@ -1,4 +1,6 @@
"""Helper methods for various modules."""
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
import enum
@ -9,16 +11,7 @@ import socket
import string
import threading
from types import MappingProxyType
from typing import (
Any,
Callable,
Coroutine,
Iterable,
KeysView,
Optional,
TypeVar,
Union,
)
from typing import Any, Callable, Coroutine, Iterable, KeysView, TypeVar
import slugify as unicode_slug
@ -106,8 +99,8 @@ def repr_helper(inp: Any) -> str:
def convert(
value: Optional[T], to_type: Callable[[T], U], default: Optional[U] = None
) -> Optional[U]:
value: T | None, to_type: Callable[[T], U], default: U | None = None
) -> U | None:
"""Convert value to to_type, returns default if fails."""
try:
return default if value is None else to_type(value)
@ -117,7 +110,7 @@ def convert(
def ensure_unique_string(
preferred_string: str, current_strings: Union[Iterable[str], KeysView[str]]
preferred_string: str, current_strings: Iterable[str] | KeysView[str]
) -> str:
"""Return a string that is not present in current_strings.
@ -213,7 +206,7 @@ class Throttle:
"""
def __init__(
self, min_time: timedelta, limit_no_throttle: Optional[timedelta] = None
self, min_time: timedelta, limit_no_throttle: timedelta | None = None
) -> None:
"""Initialize the throttle."""
self.min_time = min_time
@ -253,7 +246,7 @@ class Throttle:
)
@wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]:
def wrapper(*args: Any, **kwargs: Any) -> Callable | Coroutine:
"""Wrap that allows wrapped to be called only once per min_time.
If we cannot acquire the lock, it is running so return None.

View File

@ -1,7 +1,9 @@
"""Utilities to help with aiohttp."""
from __future__ import annotations
import io
import json
from typing import Any, Dict, Optional
from typing import Any
from urllib.parse import parse_qsl
from multidict import CIMultiDict, MultiDict
@ -26,7 +28,7 @@ class MockStreamReader:
class MockRequest:
"""Mock an aiohttp request."""
mock_source: Optional[str] = None
mock_source: str | None = None
def __init__(
self,
@ -34,8 +36,8 @@ class MockRequest:
mock_source: str,
method: str = "GET",
status: int = HTTP_OK,
headers: Optional[Dict[str, str]] = None,
query_string: Optional[str] = None,
headers: dict[str, str] | None = None,
query_string: str | None = None,
url: str = "",
) -> None:
"""Initialize a request."""

View File

@ -1,7 +1,8 @@
"""Color util methods."""
from __future__ import annotations
import colorsys
import math
from typing import List, Optional, Tuple
import attr
@ -183,7 +184,7 @@ class GamutType:
blue: XYPoint = attr.ib()
def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]:
def color_name_to_rgb(color_name: str) -> tuple[int, int, int]:
"""Convert color name to RGB hex value."""
# COLORS map has no spaces in it, so make the color_name have no
# spaces in it as well for matching purposes
@ -198,8 +199,8 @@ def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]:
def color_RGB_to_xy(
iR: int, iG: int, iB: int, Gamut: Optional[GamutType] = None
) -> Tuple[float, float]:
iR: int, iG: int, iB: int, Gamut: GamutType | None = None
) -> tuple[float, float]:
"""Convert from RGB color to XY color."""
return color_RGB_to_xy_brightness(iR, iG, iB, Gamut)[:2]
@ -208,8 +209,8 @@ def color_RGB_to_xy(
# http://www.developers.meethue.com/documentation/color-conversions-rgb-xy
# License: Code is given as is. Use at your own risk and discretion.
def color_RGB_to_xy_brightness(
iR: int, iG: int, iB: int, Gamut: Optional[GamutType] = None
) -> Tuple[float, float, int]:
iR: int, iG: int, iB: int, Gamut: GamutType | None = None
) -> tuple[float, float, int]:
"""Convert from RGB color to XY color."""
if iR + iG + iB == 0:
return 0.0, 0.0, 0
@ -248,8 +249,8 @@ def color_RGB_to_xy_brightness(
def color_xy_to_RGB(
vX: float, vY: float, Gamut: Optional[GamutType] = None
) -> Tuple[int, int, int]:
vX: float, vY: float, Gamut: GamutType | None = None
) -> tuple[int, int, int]:
"""Convert from XY to a normalized RGB."""
return color_xy_brightness_to_RGB(vX, vY, 255, Gamut)
@ -257,8 +258,8 @@ def color_xy_to_RGB(
# Converted to Python from Obj-C, original source from:
# http://www.developers.meethue.com/documentation/color-conversions-rgb-xy
def color_xy_brightness_to_RGB(
vX: float, vY: float, ibrightness: int, Gamut: Optional[GamutType] = None
) -> Tuple[int, int, int]:
vX: float, vY: float, ibrightness: int, Gamut: GamutType | None = None
) -> tuple[int, int, int]:
"""Convert from XYZ to RGB."""
if Gamut:
if not check_point_in_lamps_reach((vX, vY), Gamut):
@ -304,7 +305,7 @@ def color_xy_brightness_to_RGB(
return (ir, ig, ib)
def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]:
def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> tuple[int, int, int]:
"""Convert a hsb into its rgb representation."""
if fS == 0.0:
fV = int(fB * 255)
@ -345,7 +346,7 @@ def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]:
return (r, g, b)
def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> Tuple[float, float, float]:
def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> tuple[float, float, float]:
"""Convert an rgb color to its hsv representation.
Hue is scaled 0-360
@ -356,12 +357,12 @@ def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> Tuple[float, float, flo
return round(fHSV[0] * 360, 3), round(fHSV[1] * 100, 3), round(fHSV[2] * 100, 3)
def color_RGB_to_hs(iR: float, iG: float, iB: float) -> Tuple[float, float]:
def color_RGB_to_hs(iR: float, iG: float, iB: float) -> tuple[float, float]:
"""Convert an rgb color to its hs representation."""
return color_RGB_to_hsv(iR, iG, iB)[:2]
def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> Tuple[int, int, int]:
def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> tuple[int, int, int]:
"""Convert an hsv color into its rgb representation.
Hue is scaled 0-360
@ -372,27 +373,27 @@ def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> Tuple[int, int, int]:
return (int(fRGB[0] * 255), int(fRGB[1] * 255), int(fRGB[2] * 255))
def color_hs_to_RGB(iH: float, iS: float) -> Tuple[int, int, int]:
def color_hs_to_RGB(iH: float, iS: float) -> tuple[int, int, int]:
"""Convert an hsv color into its rgb representation."""
return color_hsv_to_RGB(iH, iS, 100)
def color_xy_to_hs(
vX: float, vY: float, Gamut: Optional[GamutType] = None
) -> Tuple[float, float]:
vX: float, vY: float, Gamut: GamutType | None = None
) -> tuple[float, float]:
"""Convert an xy color to its hs representation."""
h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY, Gamut))
return h, s
def color_hs_to_xy(
iH: float, iS: float, Gamut: Optional[GamutType] = None
) -> Tuple[float, float]:
iH: float, iS: float, Gamut: GamutType | None = None
) -> tuple[float, float]:
"""Convert an hs color to its xy representation."""
return color_RGB_to_xy(*color_hs_to_RGB(iH, iS), Gamut)
def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple:
def _match_max_scale(input_colors: tuple, output_colors: tuple) -> tuple:
"""Match the maximum value of the output to the input."""
max_in = max(input_colors)
max_out = max(output_colors)
@ -403,7 +404,7 @@ def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple:
return tuple(int(round(i * factor)) for i in output_colors)
def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]:
def color_rgb_to_rgbw(r: int, g: int, b: int) -> tuple[int, int, int, int]:
"""Convert an rgb color to an rgbw representation."""
# Calculate the white channel as the minimum of input rgb channels.
# Subtract the white portion from the remaining rgb channels.
@ -415,7 +416,7 @@ def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]:
return _match_max_scale((r, g, b), rgbw) # type: ignore
def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> Tuple[int, int, int]:
def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> tuple[int, int, int]:
"""Convert an rgbw color to an rgb representation."""
# Add the white channel back into the rgb channels.
rgb = (r + w, g + w, b + w)
@ -430,7 +431,7 @@ def color_rgb_to_hex(r: int, g: int, b: int) -> str:
return "{:02x}{:02x}{:02x}".format(round(r), round(g), round(b))
def rgb_hex_to_rgb_list(hex_string: str) -> List[int]:
def rgb_hex_to_rgb_list(hex_string: str) -> list[int]:
"""Return an RGB color value list from a hex color string."""
return [
int(hex_string[i : i + len(hex_string) // 3], 16)
@ -438,14 +439,14 @@ def rgb_hex_to_rgb_list(hex_string: str) -> List[int]:
]
def color_temperature_to_hs(color_temperature_kelvin: float) -> Tuple[float, float]:
def color_temperature_to_hs(color_temperature_kelvin: float) -> tuple[float, float]:
"""Return an hs color from a color temperature in Kelvin."""
return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin))
def color_temperature_to_rgb(
color_temperature_kelvin: float,
) -> Tuple[float, float, float]:
) -> tuple[float, float, float]:
"""
Return an RGB color from a color temperature in Kelvin.
@ -555,8 +556,8 @@ def get_closest_point_to_line(A: XYPoint, B: XYPoint, P: XYPoint) -> XYPoint:
def get_closest_point_to_point(
xy_tuple: Tuple[float, float], Gamut: GamutType
) -> Tuple[float, float]:
xy_tuple: tuple[float, float], Gamut: GamutType
) -> tuple[float, float]:
"""
Get the closest matching color within the gamut of the light.
@ -592,7 +593,7 @@ def get_closest_point_to_point(
return (cx, cy)
def check_point_in_lamps_reach(p: Tuple[float, float], Gamut: GamutType) -> bool:
def check_point_in_lamps_reach(p: tuple[float, float], Gamut: GamutType) -> bool:
"""Check if the provided XYPoint can be recreated by a Hue lamp."""
v1 = XYPoint(Gamut.green.x - Gamut.red.x, Gamut.green.y - Gamut.red.y)
v2 = XYPoint(Gamut.blue.x - Gamut.red.x, Gamut.blue.y - Gamut.red.y)

View File

@ -1,6 +1,8 @@
"""Distance util functions."""
from __future__ import annotations
from numbers import Number
from typing import Callable, Dict
from typing import Callable
from homeassistant.const import (
LENGTH,
@ -26,7 +28,7 @@ VALID_UNITS = [
LENGTH_YARD,
]
TO_METERS: Dict[str, Callable[[float], float]] = {
TO_METERS: dict[str, Callable[[float], float]] = {
LENGTH_METERS: lambda meters: meters,
LENGTH_MILES: lambda miles: miles * 1609.344,
LENGTH_YARD: lambda yards: yards * 0.9144,
@ -37,7 +39,7 @@ TO_METERS: Dict[str, Callable[[float], float]] = {
LENGTH_MILLIMETERS: lambda millimeters: millimeters * 0.001,
}
METERS_TO: Dict[str, Callable[[float], float]] = {
METERS_TO: dict[str, Callable[[float], float]] = {
LENGTH_METERS: lambda meters: meters,
LENGTH_MILES: lambda meters: meters * 0.000621371,
LENGTH_YARD: lambda meters: meters * 1.09361,

View File

@ -1,7 +1,9 @@
"""Helper methods to handle the time in Home Assistant."""
from __future__ import annotations
import datetime as dt
import re
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, cast
import ciso8601
import pytz
@ -40,7 +42,7 @@ def set_default_time_zone(time_zone: dt.tzinfo) -> None:
DEFAULT_TIME_ZONE = time_zone
def get_time_zone(time_zone_str: str) -> Optional[dt.tzinfo]:
def get_time_zone(time_zone_str: str) -> dt.tzinfo | None:
"""Get time zone from string. Return None if unable to determine.
Async friendly.
@ -56,7 +58,7 @@ def utcnow() -> dt.datetime:
return dt.datetime.now(NATIVE_UTC)
def now(time_zone: Optional[dt.tzinfo] = None) -> dt.datetime:
def now(time_zone: dt.tzinfo | None = None) -> dt.datetime:
"""Get now in specified time zone."""
return dt.datetime.now(time_zone or DEFAULT_TIME_ZONE)
@ -77,7 +79,7 @@ def as_utc(dattim: dt.datetime) -> dt.datetime:
def as_timestamp(dt_value: dt.datetime) -> float:
"""Convert a date/time into a unix time (seconds since 1970)."""
if hasattr(dt_value, "timestamp"):
parsed_dt: Optional[dt.datetime] = dt_value
parsed_dt: dt.datetime | None = dt_value
else:
parsed_dt = parse_datetime(str(dt_value))
if parsed_dt is None:
@ -100,9 +102,7 @@ def utc_from_timestamp(timestamp: float) -> dt.datetime:
return UTC.localize(dt.datetime.utcfromtimestamp(timestamp))
def start_of_local_day(
dt_or_d: Union[dt.date, dt.datetime, None] = None
) -> dt.datetime:
def start_of_local_day(dt_or_d: dt.date | dt.datetime | None = None) -> dt.datetime:
"""Return local datetime object of start of day from date or datetime."""
if dt_or_d is None:
date: dt.date = now().date()
@ -119,7 +119,7 @@ def start_of_local_day(
# Copyright (c) Django Software Foundation and individual contributors.
# All rights reserved.
# https://github.com/django/django/blob/master/LICENSE
def parse_datetime(dt_str: str) -> Optional[dt.datetime]:
def parse_datetime(dt_str: str) -> dt.datetime | None:
"""Parse a string and return a datetime.datetime.
This function supports time zone offsets. When the input contains one,
@ -134,12 +134,12 @@ def parse_datetime(dt_str: str) -> Optional[dt.datetime]:
match = DATETIME_RE.match(dt_str)
if not match:
return None
kws: Dict[str, Any] = match.groupdict()
kws: dict[str, Any] = match.groupdict()
if kws["microsecond"]:
kws["microsecond"] = kws["microsecond"].ljust(6, "0")
tzinfo_str = kws.pop("tzinfo")
tzinfo: Optional[dt.tzinfo] = None
tzinfo: dt.tzinfo | None = None
if tzinfo_str == "Z":
tzinfo = UTC
elif tzinfo_str is not None:
@ -154,7 +154,7 @@ def parse_datetime(dt_str: str) -> Optional[dt.datetime]:
return dt.datetime(**kws)
def parse_date(dt_str: str) -> Optional[dt.date]:
def parse_date(dt_str: str) -> dt.date | None:
"""Convert a date string to a date object."""
try:
return dt.datetime.strptime(dt_str, DATE_STR_FORMAT).date()
@ -162,7 +162,7 @@ def parse_date(dt_str: str) -> Optional[dt.date]:
return None
def parse_time(time_str: str) -> Optional[dt.time]:
def parse_time(time_str: str) -> dt.time | None:
"""Parse a time string (00:20:00) into Time object.
Return None if invalid.
@ -213,7 +213,7 @@ def get_age(date: dt.datetime) -> str:
return formatn(rounded_delta, selected_unit)
def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> List[int]:
def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> list[int]:
"""Parse the time expression part and return a list of times to match."""
if parameter is None or parameter == MATCH_ALL:
res = list(range(min_value, max_value + 1))
@ -241,9 +241,9 @@ def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> Lis
def find_next_time_expression_time(
now: dt.datetime, # pylint: disable=redefined-outer-name
seconds: List[int],
minutes: List[int],
hours: List[int],
seconds: list[int],
minutes: list[int],
hours: list[int],
) -> dt.datetime:
"""Find the next datetime from now for which the time expression matches.
@ -257,7 +257,7 @@ def find_next_time_expression_time(
if not seconds or not minutes or not hours:
raise ValueError("Cannot find a next time: Time expression never matches!")
def _lower_bound(arr: List[int], cmp: int) -> Optional[int]:
def _lower_bound(arr: list[int], cmp: int) -> int | None:
"""Return the first value in arr greater or equal to cmp.
Return None if no such value exists.

View File

@ -1,10 +1,12 @@
"""JSON utility functions."""
from __future__ import annotations
from collections import deque
import json
import logging
import os
import tempfile
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable
from homeassistant.core import Event, State
from homeassistant.exceptions import HomeAssistantError
@ -20,9 +22,7 @@ class WriteError(HomeAssistantError):
"""Error writing the data."""
def load_json(
filename: str, default: Union[List, Dict, None] = None
) -> Union[List, Dict]:
def load_json(filename: str, default: list | dict | None = None) -> list | dict:
"""Load JSON data from a file and return as dict or list.
Defaults to returning empty dict if file is not found.
@ -44,10 +44,10 @@ def load_json(
def save_json(
filename: str,
data: Union[List, Dict],
data: list | dict,
private: bool = False,
*,
encoder: Optional[Type[json.JSONEncoder]] = None,
encoder: type[json.JSONEncoder] | None = None,
) -> None:
"""Save JSON data to a file.
@ -85,7 +85,7 @@ def save_json(
_LOGGER.error("JSON replacement cleanup failed: %s", err)
def format_unserializable_data(data: Dict[str, Any]) -> str:
def format_unserializable_data(data: dict[str, Any]) -> str:
"""Format output of find_paths in a friendly way.
Format is comma separated: <path>=<value>(<type>)
@ -95,7 +95,7 @@ def format_unserializable_data(data: Dict[str, Any]) -> str:
def find_paths_unserializable_data(
bad_data: Any, *, dump: Callable[[Any], str] = json.dumps
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Find the paths to unserializable data.
This method is slow! Only use for error handling.

View File

@ -3,10 +3,12 @@ Module with location helpers.
detect_location_info and elevation are mocked by default during tests.
"""
from __future__ import annotations
import asyncio
import collections
import math
from typing import Any, Dict, Optional, Tuple
from typing import Any
import aiohttp
@ -47,7 +49,7 @@ LocationInfo = collections.namedtuple(
async def async_detect_location_info(
session: aiohttp.ClientSession,
) -> Optional[LocationInfo]:
) -> LocationInfo | None:
"""Detect location information."""
data = await _get_ipapi(session)
@ -63,8 +65,8 @@ async def async_detect_location_info(
def distance(
lat1: Optional[float], lon1: Optional[float], lat2: float, lon2: float
) -> Optional[float]:
lat1: float | None, lon1: float | None, lat2: float, lon2: float
) -> float | None:
"""Calculate the distance in meters between two points.
Async friendly.
@ -81,8 +83,8 @@ def distance(
# Source: https://github.com/maurycyp/vincenty
# License: https://github.com/maurycyp/vincenty/blob/master/LICENSE
def vincenty(
point1: Tuple[float, float], point2: Tuple[float, float], miles: bool = False
) -> Optional[float]:
point1: tuple[float, float], point2: tuple[float, float], miles: bool = False
) -> float | None:
"""
Vincenty formula (inverse method) to calculate the distance.
@ -162,7 +164,7 @@ def vincenty(
return round(s, 6)
async def _get_ipapi(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]:
async def _get_ipapi(session: aiohttp.ClientSession) -> dict[str, Any] | None:
"""Query ipapi.co for location data."""
try:
resp = await session.get(IPAPI, timeout=5)
@ -192,7 +194,7 @@ async def _get_ipapi(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]
}
async def _get_ip_api(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]:
async def _get_ip_api(session: aiohttp.ClientSession) -> dict[str, Any] | None:
"""Query ip-api.com for location data."""
try:
resp = await session.get(IP_API, timeout=5)

View File

@ -1,4 +1,6 @@
"""Logging utilities."""
from __future__ import annotations
import asyncio
from functools import partial, wraps
import inspect
@ -6,7 +8,7 @@ import logging
import logging.handlers
import queue
import traceback
from typing import Any, Awaitable, Callable, Coroutine, Union, cast, overload
from typing import Any, Awaitable, Callable, Coroutine, cast, overload
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE
from homeassistant.core import HomeAssistant, callback
@ -115,7 +117,7 @@ def catch_log_exception(
def catch_log_exception(
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
) -> Union[Callable[..., None], Callable[..., Awaitable[None]]]:
) -> Callable[..., None] | Callable[..., Awaitable[None]]:
"""Decorate a callback to catch and log exceptions."""
# Check for partials to properly determine if coroutine function
@ -123,7 +125,7 @@ def catch_log_exception(
while isinstance(check_func, partial):
check_func = check_func.func
wrapper_func: Union[Callable[..., None], Callable[..., Awaitable[None]]]
wrapper_func: Callable[..., None] | Callable[..., Awaitable[None]]
if asyncio.iscoroutinefunction(check_func):
async_func = cast(Callable[..., Awaitable[None]], func)

View File

@ -1,6 +1,7 @@
"""Network utilities."""
from __future__ import annotations
from ipaddress import IPv4Address, IPv6Address, ip_address, ip_network
from typing import Union
import yarl
@ -23,22 +24,22 @@ PRIVATE_NETWORKS = (
LINK_LOCAL_NETWORK = ip_network("169.254.0.0/16")
def is_loopback(address: Union[IPv4Address, IPv6Address]) -> bool:
def is_loopback(address: IPv4Address | IPv6Address) -> bool:
"""Check if an address is a loopback address."""
return any(address in network for network in LOOPBACK_NETWORKS)
def is_private(address: Union[IPv4Address, IPv6Address]) -> bool:
def is_private(address: IPv4Address | IPv6Address) -> bool:
"""Check if an address is a private address."""
return any(address in network for network in PRIVATE_NETWORKS)
def is_link_local(address: Union[IPv4Address, IPv6Address]) -> bool:
def is_link_local(address: IPv4Address | IPv6Address) -> bool:
"""Check if an address is link local."""
return address in LINK_LOCAL_NETWORK
def is_local(address: Union[IPv4Address, IPv6Address]) -> bool:
def is_local(address: IPv4Address | IPv6Address) -> bool:
"""Check if an address is loopback or private."""
return is_loopback(address) or is_private(address)

View File

@ -1,4 +1,6 @@
"""Helpers to install PyPi packages."""
from __future__ import annotations
import asyncio
from importlib.metadata import PackageNotFoundError, version
import logging
@ -6,7 +8,6 @@ import os
from pathlib import Path
from subprocess import PIPE, Popen
import sys
from typing import Optional
from urllib.parse import urlparse
import pkg_resources
@ -59,10 +60,10 @@ def is_installed(package: str) -> bool:
def install_package(
package: str,
upgrade: bool = True,
target: Optional[str] = None,
constraints: Optional[str] = None,
find_links: Optional[str] = None,
no_cache_dir: Optional[bool] = False,
target: str | None = None,
constraints: str | None = None,
find_links: str | None = None,
no_cache_dir: bool | None = False,
) -> bool:
"""Install a package on PyPi. Accepts pip compatible package strings.

View File

@ -1,9 +1,8 @@
"""Percentage util functions."""
from typing import List, Tuple
from __future__ import annotations
def ordered_list_item_to_percentage(ordered_list: List[str], item: str) -> int:
def ordered_list_item_to_percentage(ordered_list: list[str], item: str) -> int:
"""Determine the percentage of an item in an ordered list.
When using this utility for fan speeds, do not include "off"
@ -26,7 +25,7 @@ def ordered_list_item_to_percentage(ordered_list: List[str], item: str) -> int:
return (list_position * 100) // list_len
def percentage_to_ordered_list_item(ordered_list: List[str], percentage: int) -> str:
def percentage_to_ordered_list_item(ordered_list: list[str], percentage: int) -> str:
"""Find the item that most closely matches the percentage in an ordered list.
When using this utility for fan speeds, do not include "off"
@ -54,7 +53,7 @@ def percentage_to_ordered_list_item(ordered_list: List[str], percentage: int) ->
def ranged_value_to_percentage(
low_high_range: Tuple[float, float], value: float
low_high_range: tuple[float, float], value: float
) -> int:
"""Given a range of low and high values convert a single value to a percentage.
@ -71,7 +70,7 @@ def ranged_value_to_percentage(
def percentage_to_ranged_value(
low_high_range: Tuple[float, float], percentage: int
low_high_range: tuple[float, float], percentage: int
) -> float:
"""Given a range of low and high values convert a percentage to a single value.
@ -87,11 +86,11 @@ def percentage_to_ranged_value(
return states_in_range(low_high_range) * percentage / 100
def states_in_range(low_high_range: Tuple[float, float]) -> float:
def states_in_range(low_high_range: tuple[float, float]) -> float:
"""Given a range of low and high values return how many states exist."""
return low_high_range[1] - low_high_range[0] + 1
def int_states_in_range(low_high_range: Tuple[float, float]) -> int:
def int_states_in_range(low_high_range: tuple[float, float]) -> int:
"""Given a range of low and high values return how many integer states exist."""
return int(states_in_range(low_high_range))

View File

@ -2,18 +2,18 @@
Can only be used by integrations that have pillow in their requirements.
"""
from typing import Tuple
from __future__ import annotations
from PIL import ImageDraw
def draw_box(
draw: ImageDraw,
box: Tuple[float, float, float, float],
box: tuple[float, float, float, float],
img_width: int,
img_height: int,
text: str = "",
color: Tuple[int, int, int] = (255, 255, 0),
color: tuple[int, int, int] = (255, 255, 0),
) -> None:
"""
Draw a bounding box on and image.

View File

@ -1,9 +1,11 @@
"""ruamel.yaml utility functions."""
from __future__ import annotations
from collections import OrderedDict
import logging
import os
from os import O_CREAT, O_TRUNC, O_WRONLY, stat_result
from typing import Dict, List, Optional, Union
from typing import Dict, List, Union
import ruamel.yaml
from ruamel.yaml import YAML # type: ignore
@ -22,7 +24,7 @@ JSON_TYPE = Union[List, Dict, str] # pylint: disable=invalid-name
class ExtSafeConstructor(SafeConstructor):
"""Extended SafeConstructor."""
name: Optional[str] = None
name: str | None = None
class UnsupportedYamlError(HomeAssistantError):
@ -77,7 +79,7 @@ def yaml_to_object(data: str) -> JSON_TYPE:
"""Create object from yaml string."""
yaml = YAML(typ="rt")
try:
result: Union[List, Dict, str] = yaml.load(data)
result: list | dict | str = yaml.load(data)
return result
except YAMLError as exc:
_LOGGER.error("YAML error: %s", exc)

View File

@ -8,7 +8,7 @@ from __future__ import annotations
import asyncio
import enum
from types import TracebackType
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any
from .async_ import run_callback_threadsafe
@ -38,10 +38,10 @@ class _GlobalFreezeContext:
async def __aexit__(
self,
exc_type: Type[BaseException],
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
) -> Optional[bool]:
) -> bool | None:
self._exit()
return None
@ -51,10 +51,10 @@ class _GlobalFreezeContext:
def __exit__( # pylint: disable=useless-return
self,
exc_type: Type[BaseException],
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
) -> Optional[bool]:
) -> bool | None:
self._loop.call_soon_threadsafe(self._exit)
return None
@ -106,10 +106,10 @@ class _ZoneFreezeContext:
async def __aexit__(
self,
exc_type: Type[BaseException],
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
) -> Optional[bool]:
) -> bool | None:
self._exit()
return None
@ -119,10 +119,10 @@ class _ZoneFreezeContext:
def __exit__( # pylint: disable=useless-return
self,
exc_type: Type[BaseException],
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
) -> Optional[bool]:
) -> bool | None:
self._loop.call_soon_threadsafe(self._exit)
return None
@ -155,8 +155,8 @@ class _GlobalTaskContext:
self._manager: TimeoutManager = manager
self._task: asyncio.Task[Any] = task
self._time_left: float = timeout
self._expiration_time: Optional[float] = None
self._timeout_handler: Optional[asyncio.Handle] = None
self._expiration_time: float | None = None
self._timeout_handler: asyncio.Handle | None = None
self._wait_zone: asyncio.Event = asyncio.Event()
self._state: _State = _State.INIT
self._cool_down: float = cool_down
@ -169,10 +169,10 @@ class _GlobalTaskContext:
async def __aexit__(
self,
exc_type: Type[BaseException],
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
) -> Optional[bool]:
) -> bool | None:
self._stop_timer()
self._manager.global_tasks.remove(self)
@ -263,8 +263,8 @@ class _ZoneTaskContext:
self._task: asyncio.Task[Any] = task
self._state: _State = _State.INIT
self._time_left: float = timeout
self._expiration_time: Optional[float] = None
self._timeout_handler: Optional[asyncio.Handle] = None
self._expiration_time: float | None = None
self._timeout_handler: asyncio.Handle | None = None
@property
def state(self) -> _State:
@ -283,10 +283,10 @@ class _ZoneTaskContext:
async def __aexit__(
self,
exc_type: Type[BaseException],
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
) -> Optional[bool]:
) -> bool | None:
self._zone.exit_task(self)
self._stop_timer()
@ -344,8 +344,8 @@ class _ZoneTimeoutManager:
"""Initialize internal timeout context manager."""
self._manager: TimeoutManager = manager
self._zone: str = zone
self._tasks: List[_ZoneTaskContext] = []
self._freezes: List[_ZoneFreezeContext] = []
self._tasks: list[_ZoneTaskContext] = []
self._freezes: list[_ZoneFreezeContext] = []
def __repr__(self) -> str:
"""Representation of a zone."""
@ -418,9 +418,9 @@ class TimeoutManager:
def __init__(self) -> None:
"""Initialize TimeoutManager."""
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
self._zones: Dict[str, _ZoneTimeoutManager] = {}
self._globals: List[_GlobalTaskContext] = []
self._freezes: List[_GlobalFreezeContext] = []
self._zones: dict[str, _ZoneTimeoutManager] = {}
self._globals: list[_GlobalTaskContext] = []
self._freezes: list[_GlobalFreezeContext] = []
@property
def zones_done(self) -> bool:
@ -433,17 +433,17 @@ class TimeoutManager:
return not self._freezes
@property
def zones(self) -> Dict[str, _ZoneTimeoutManager]:
def zones(self) -> dict[str, _ZoneTimeoutManager]:
"""Return all Zones."""
return self._zones
@property
def global_tasks(self) -> List[_GlobalTaskContext]:
def global_tasks(self) -> list[_GlobalTaskContext]:
"""Return all global Tasks."""
return self._globals
@property
def global_freezes(self) -> List[_GlobalFreezeContext]:
def global_freezes(self) -> list[_GlobalFreezeContext]:
"""Return all global Freezes."""
return self._freezes
@ -459,12 +459,12 @@ class TimeoutManager:
def async_timeout(
self, timeout: float, zone_name: str = ZONE_GLOBAL, cool_down: float = 0
) -> Union[_ZoneTaskContext, _GlobalTaskContext]:
) -> _ZoneTaskContext | _GlobalTaskContext:
"""Timeout based on a zone.
For using as Async Context Manager.
"""
current_task: Optional[asyncio.Task[Any]] = asyncio.current_task()
current_task: asyncio.Task[Any] | None = asyncio.current_task()
assert current_task
# Global Zone
@ -483,7 +483,7 @@ class TimeoutManager:
def async_freeze(
self, zone_name: str = ZONE_GLOBAL
) -> Union[_ZoneFreezeContext, _GlobalFreezeContext]:
) -> _ZoneFreezeContext | _GlobalFreezeContext:
"""Freeze all timer until job is done.
For using as Async Context Manager.
@ -502,7 +502,7 @@ class TimeoutManager:
def freeze(
self, zone_name: str = ZONE_GLOBAL
) -> Union[_ZoneFreezeContext, _GlobalFreezeContext]:
) -> _ZoneFreezeContext | _GlobalFreezeContext:
"""Freeze all timer until job is done.
For using as Context Manager.

View File

@ -1,6 +1,7 @@
"""Unit system helper class and methods."""
from __future__ import annotations
from numbers import Number
from typing import Dict, Optional
from homeassistant.const import (
CONF_UNIT_SYSTEM_IMPERIAL,
@ -109,7 +110,7 @@ class UnitSystem:
return temperature_util.convert(temperature, from_unit, self.temperature_unit)
def length(self, length: Optional[float], from_unit: str) -> float:
def length(self, length: float | None, from_unit: str) -> float:
"""Convert the given length to this unit system."""
if not isinstance(length, Number):
raise TypeError(f"{length!s} is not a numeric value.")
@ -119,7 +120,7 @@ class UnitSystem:
length, from_unit, self.length_unit
)
def pressure(self, pressure: Optional[float], from_unit: str) -> float:
def pressure(self, pressure: float | None, from_unit: str) -> float:
"""Convert the given pressure to this unit system."""
if not isinstance(pressure, Number):
raise TypeError(f"{pressure!s} is not a numeric value.")
@ -129,7 +130,7 @@ class UnitSystem:
pressure, from_unit, self.pressure_unit
)
def volume(self, volume: Optional[float], from_unit: str) -> float:
def volume(self, volume: float | None, from_unit: str) -> float:
"""Convert the given volume to this unit system."""
if not isinstance(volume, Number):
raise TypeError(f"{volume!s} is not a numeric value.")
@ -137,7 +138,7 @@ class UnitSystem:
# type ignore: https://github.com/python/mypy/issues/7207
return volume_util.convert(volume, from_unit, self.volume_unit) # type: ignore
def as_dict(self) -> Dict[str, str]:
def as_dict(self) -> dict[str, str]:
"""Convert the unit system to a dictionary."""
return {
LENGTH: self.length_unit,

View File

@ -1,6 +1,7 @@
"""Deal with YAML input."""
from __future__ import annotations
from typing import Any, Dict, Set
from typing import Any
from .objects import Input
@ -14,14 +15,14 @@ class UndefinedSubstitution(Exception):
self.input = input
def extract_inputs(obj: Any) -> Set[str]:
def extract_inputs(obj: Any) -> set[str]:
"""Extract input from a structure."""
found: Set[str] = set()
found: set[str] = set()
_extract_inputs(obj, found)
return found
def _extract_inputs(obj: Any, found: Set[str]) -> None:
def _extract_inputs(obj: Any, found: set[str]) -> None:
"""Extract input from a structure."""
if isinstance(obj, Input):
found.add(obj.name)
@ -38,7 +39,7 @@ def _extract_inputs(obj: Any, found: Set[str]) -> None:
return
def substitute(obj: Any, substitutions: Dict[str, Any]) -> Any:
def substitute(obj: Any, substitutions: dict[str, Any]) -> Any:
"""Substitute values."""
if isinstance(obj, Input):
if obj.name not in substitutions:

View File

@ -1,10 +1,12 @@
"""Custom loader."""
from __future__ import annotations
from collections import OrderedDict
import fnmatch
import logging
import os
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, TextIO, TypeVar, Union, overload
from typing import Any, Dict, Iterator, List, TextIO, TypeVar, Union, overload
import yaml
@ -27,7 +29,7 @@ class Secrets:
def __init__(self, config_dir: Path):
"""Initialize secrets."""
self.config_dir = config_dir
self._cache: Dict[Path, Dict[str, str]] = {}
self._cache: dict[Path, dict[str, str]] = {}
def get(self, requester_path: str, secret: str) -> str:
"""Return the value of a secret."""
@ -55,7 +57,7 @@ class Secrets:
raise HomeAssistantError(f"Secret {secret} not defined")
def _load_secret_yaml(self, secret_dir: Path) -> Dict[str, str]:
def _load_secret_yaml(self, secret_dir: Path) -> dict[str, str]:
"""Load the secrets yaml from path."""
secret_path = secret_dir / SECRET_YAML
@ -90,7 +92,7 @@ class Secrets:
class SafeLineLoader(yaml.SafeLoader):
"""Loader class that keeps track of line numbers."""
def __init__(self, stream: Any, secrets: Optional[Secrets] = None) -> None:
def __init__(self, stream: Any, secrets: Secrets | None = None) -> None:
"""Initialize a safe line loader."""
super().__init__(stream)
self.secrets = secrets
@ -103,7 +105,7 @@ class SafeLineLoader(yaml.SafeLoader):
return node
def load_yaml(fname: str, secrets: Optional[Secrets] = None) -> JSON_TYPE:
def load_yaml(fname: str, secrets: Secrets | None = None) -> JSON_TYPE:
"""Load a YAML file."""
try:
with open(fname, encoding="utf-8") as conf_file:
@ -113,9 +115,7 @@ def load_yaml(fname: str, secrets: Optional[Secrets] = None) -> JSON_TYPE:
raise HomeAssistantError(exc) from exc
def parse_yaml(
content: Union[str, TextIO], secrets: Optional[Secrets] = None
) -> JSON_TYPE:
def parse_yaml(content: str | TextIO, secrets: Secrets | None = None) -> JSON_TYPE:
"""Load a YAML file."""
try:
# If configuration file is empty YAML returns None
@ -131,14 +131,14 @@ def parse_yaml(
@overload
def _add_reference(
obj: Union[list, NodeListClass], loader: SafeLineLoader, node: yaml.nodes.Node
obj: list | NodeListClass, loader: SafeLineLoader, node: yaml.nodes.Node
) -> NodeListClass:
...
@overload
def _add_reference(
obj: Union[str, NodeStrClass], loader: SafeLineLoader, node: yaml.nodes.Node
obj: str | NodeStrClass, loader: SafeLineLoader, node: yaml.nodes.Node
) -> NodeStrClass:
...
@ -223,7 +223,7 @@ def _include_dir_merge_named_yaml(
def _include_dir_list_yaml(
loader: SafeLineLoader, node: yaml.nodes.Node
) -> List[JSON_TYPE]:
) -> list[JSON_TYPE]:
"""Load multiple files from directory as a list."""
loc = os.path.join(os.path.dirname(loader.name), node.value)
return [
@ -238,7 +238,7 @@ def _include_dir_merge_list_yaml(
) -> JSON_TYPE:
"""Load multiple files from directory as a merged list."""
loc: str = os.path.join(os.path.dirname(loader.name), node.value)
merged_list: List[JSON_TYPE] = []
merged_list: list[JSON_TYPE] = []
for fname in _find_files(loc, "*.yaml"):
if os.path.basename(fname) == SECRET_YAML:
continue
@ -253,7 +253,7 @@ def _ordered_dict(loader: SafeLineLoader, node: yaml.nodes.MappingNode) -> Order
loader.flatten_mapping(node)
nodes = loader.construct_pairs(node)
seen: Dict = {}
seen: dict = {}
for (key, _), (child_node, _) in zip(nodes, node.value):
line = child_node.start_mark.line