Update typing 03 (#48015)
parent
6fb2e63e49
commit
fabd73f08b
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from typing import Any, Dict, Optional, Tuple, cast
|
||||
|
||||
import jwt
|
||||
|
||||
|
@ -36,8 +36,8 @@ class InvalidProvider(Exception):
|
|||
|
||||
async def auth_manager_from_config(
|
||||
hass: HomeAssistant,
|
||||
provider_configs: List[Dict[str, Any]],
|
||||
module_configs: List[Dict[str, Any]],
|
||||
provider_configs: list[dict[str, Any]],
|
||||
module_configs: list[dict[str, Any]],
|
||||
) -> AuthManager:
|
||||
"""Initialize an auth manager from config.
|
||||
|
||||
|
@ -87,8 +87,8 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
|||
self,
|
||||
handler_key: Any,
|
||||
*,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> data_entry_flow.FlowHandler:
|
||||
"""Create a login flow."""
|
||||
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
|
||||
|
@ -97,8 +97,8 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
|||
return await auth_provider.async_login_flow(context)
|
||||
|
||||
async def async_finish_flow(
|
||||
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
self, flow: data_entry_flow.FlowHandler, result: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Return a user as result of login flow."""
|
||||
flow = cast(LoginFlow, flow)
|
||||
|
||||
|
@ -157,22 +157,22 @@ class AuthManager:
|
|||
self.login_flow = AuthManagerFlowManager(hass, self)
|
||||
|
||||
@property
|
||||
def auth_providers(self) -> List[AuthProvider]:
|
||||
def auth_providers(self) -> list[AuthProvider]:
|
||||
"""Return a list of available auth providers."""
|
||||
return list(self._providers.values())
|
||||
|
||||
@property
|
||||
def auth_mfa_modules(self) -> List[MultiFactorAuthModule]:
|
||||
def auth_mfa_modules(self) -> list[MultiFactorAuthModule]:
|
||||
"""Return a list of available auth modules."""
|
||||
return list(self._mfa_modules.values())
|
||||
|
||||
def get_auth_provider(
|
||||
self, provider_type: str, provider_id: Optional[str]
|
||||
) -> Optional[AuthProvider]:
|
||||
self, provider_type: str, provider_id: str | None
|
||||
) -> AuthProvider | None:
|
||||
"""Return an auth provider, None if not found."""
|
||||
return self._providers.get((provider_type, provider_id))
|
||||
|
||||
def get_auth_providers(self, provider_type: str) -> List[AuthProvider]:
|
||||
def get_auth_providers(self, provider_type: str) -> list[AuthProvider]:
|
||||
"""Return a List of auth provider of one type, Empty if not found."""
|
||||
return [
|
||||
provider
|
||||
|
@ -180,30 +180,30 @@ class AuthManager:
|
|||
if p_type == provider_type
|
||||
]
|
||||
|
||||
def get_auth_mfa_module(self, module_id: str) -> Optional[MultiFactorAuthModule]:
|
||||
def get_auth_mfa_module(self, module_id: str) -> MultiFactorAuthModule | None:
|
||||
"""Return a multi-factor auth module, None if not found."""
|
||||
return self._mfa_modules.get(module_id)
|
||||
|
||||
async def async_get_users(self) -> List[models.User]:
|
||||
async def async_get_users(self) -> list[models.User]:
|
||||
"""Retrieve all users."""
|
||||
return await self._store.async_get_users()
|
||||
|
||||
async def async_get_user(self, user_id: str) -> Optional[models.User]:
|
||||
async def async_get_user(self, user_id: str) -> models.User | None:
|
||||
"""Retrieve a user."""
|
||||
return await self._store.async_get_user(user_id)
|
||||
|
||||
async def async_get_owner(self) -> Optional[models.User]:
|
||||
async def async_get_owner(self) -> models.User | None:
|
||||
"""Retrieve the owner."""
|
||||
users = await self.async_get_users()
|
||||
return next((user for user in users if user.is_owner), None)
|
||||
|
||||
async def async_get_group(self, group_id: str) -> Optional[models.Group]:
|
||||
async def async_get_group(self, group_id: str) -> models.Group | None:
|
||||
"""Retrieve all groups."""
|
||||
return await self._store.async_get_group(group_id)
|
||||
|
||||
async def async_get_user_by_credentials(
|
||||
self, credentials: models.Credentials
|
||||
) -> Optional[models.User]:
|
||||
) -> models.User | None:
|
||||
"""Get a user by credential, return None if not found."""
|
||||
for user in await self.async_get_users():
|
||||
for creds in user.credentials:
|
||||
|
@ -213,7 +213,7 @@ class AuthManager:
|
|||
return None
|
||||
|
||||
async def async_create_system_user(
|
||||
self, name: str, group_ids: Optional[List[str]] = None
|
||||
self, name: str, group_ids: list[str] | None = None
|
||||
) -> models.User:
|
||||
"""Create a system user."""
|
||||
user = await self._store.async_create_user(
|
||||
|
@ -225,10 +225,10 @@ class AuthManager:
|
|||
return user
|
||||
|
||||
async def async_create_user(
|
||||
self, name: str, group_ids: Optional[List[str]] = None
|
||||
self, name: str, group_ids: list[str] | None = None
|
||||
) -> models.User:
|
||||
"""Create a user."""
|
||||
kwargs: Dict[str, Any] = {
|
||||
kwargs: dict[str, Any] = {
|
||||
"name": name,
|
||||
"is_active": True,
|
||||
"group_ids": group_ids or [],
|
||||
|
@ -294,12 +294,12 @@ class AuthManager:
|
|||
async def async_update_user(
|
||||
self,
|
||||
user: models.User,
|
||||
name: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
group_ids: Optional[List[str]] = None,
|
||||
name: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Update a user."""
|
||||
kwargs: Dict[str, Any] = {}
|
||||
kwargs: dict[str, Any] = {}
|
||||
if name is not None:
|
||||
kwargs["name"] = name
|
||||
if group_ids is not None:
|
||||
|
@ -362,9 +362,9 @@ class AuthManager:
|
|||
|
||||
await module.async_depose_user(user.id)
|
||||
|
||||
async def async_get_enabled_mfa(self, user: models.User) -> Dict[str, str]:
|
||||
async def async_get_enabled_mfa(self, user: models.User) -> dict[str, str]:
|
||||
"""List enabled mfa modules for user."""
|
||||
modules: Dict[str, str] = OrderedDict()
|
||||
modules: dict[str, str] = OrderedDict()
|
||||
for module_id, module in self._mfa_modules.items():
|
||||
if await module.async_is_user_setup(user.id):
|
||||
modules[module_id] = module.name
|
||||
|
@ -373,12 +373,12 @@ class AuthManager:
|
|||
async def async_create_refresh_token(
|
||||
self,
|
||||
user: models.User,
|
||||
client_id: Optional[str] = None,
|
||||
client_name: Optional[str] = None,
|
||||
client_icon: Optional[str] = None,
|
||||
token_type: Optional[str] = None,
|
||||
client_id: str | None = None,
|
||||
client_name: str | None = None,
|
||||
client_icon: str | None = None,
|
||||
token_type: str | None = None,
|
||||
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
||||
credential: Optional[models.Credentials] = None,
|
||||
credential: models.Credentials | None = None,
|
||||
) -> models.RefreshToken:
|
||||
"""Create a new refresh token for a user."""
|
||||
if not user.is_active:
|
||||
|
@ -432,13 +432,13 @@ class AuthManager:
|
|||
|
||||
async def async_get_refresh_token(
|
||||
self, token_id: str
|
||||
) -> Optional[models.RefreshToken]:
|
||||
) -> models.RefreshToken | None:
|
||||
"""Get refresh token by id."""
|
||||
return await self._store.async_get_refresh_token(token_id)
|
||||
|
||||
async def async_get_refresh_token_by_token(
|
||||
self, token: str
|
||||
) -> Optional[models.RefreshToken]:
|
||||
) -> models.RefreshToken | None:
|
||||
"""Get refresh token by token."""
|
||||
return await self._store.async_get_refresh_token_by_token(token)
|
||||
|
||||
|
@ -450,7 +450,7 @@ class AuthManager:
|
|||
|
||||
@callback
|
||||
def async_create_access_token(
|
||||
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
|
||||
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
|
||||
) -> str:
|
||||
"""Create a new access token."""
|
||||
self.async_validate_refresh_token(refresh_token, remote_ip)
|
||||
|
@ -471,7 +471,7 @@ class AuthManager:
|
|||
@callback
|
||||
def _async_resolve_provider(
|
||||
self, refresh_token: models.RefreshToken
|
||||
) -> Optional[AuthProvider]:
|
||||
) -> AuthProvider | None:
|
||||
"""Get the auth provider for the given refresh token.
|
||||
|
||||
Raises an exception if the expected provider is no longer available or return
|
||||
|
@ -492,7 +492,7 @@ class AuthManager:
|
|||
|
||||
@callback
|
||||
def async_validate_refresh_token(
|
||||
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
|
||||
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
|
||||
) -> None:
|
||||
"""Validate that a refresh token is usable.
|
||||
|
||||
|
@ -504,7 +504,7 @@ class AuthManager:
|
|||
|
||||
async def async_validate_access_token(
|
||||
self, token: str
|
||||
) -> Optional[models.RefreshToken]:
|
||||
) -> models.RefreshToken | None:
|
||||
"""Return refresh token if an access token is valid."""
|
||||
try:
|
||||
unverif_claims = jwt.decode(token, verify=False)
|
||||
|
@ -535,7 +535,7 @@ class AuthManager:
|
|||
@callback
|
||||
def _async_get_auth_provider(
|
||||
self, credentials: models.Credentials
|
||||
) -> Optional[AuthProvider]:
|
||||
) -> AuthProvider | None:
|
||||
"""Get auth provider from a set of credentials."""
|
||||
auth_provider_key = (
|
||||
credentials.auth_provider_type,
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
"""Storage for auth models."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
import hmac
|
||||
from logging import getLogger
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.auth.const import ACCESS_TOKEN_EXPIRATION
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
@ -34,15 +36,15 @@ class AuthStore:
|
|||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the auth store."""
|
||||
self.hass = hass
|
||||
self._users: Optional[Dict[str, models.User]] = None
|
||||
self._groups: Optional[Dict[str, models.Group]] = None
|
||||
self._perm_lookup: Optional[PermissionLookup] = None
|
||||
self._users: dict[str, models.User] | None = None
|
||||
self._groups: dict[str, models.Group] | None = None
|
||||
self._perm_lookup: PermissionLookup | None = None
|
||||
self._store = hass.helpers.storage.Store(
|
||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
||||
)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def async_get_groups(self) -> List[models.Group]:
|
||||
async def async_get_groups(self) -> list[models.Group]:
|
||||
"""Retrieve all users."""
|
||||
if self._groups is None:
|
||||
await self._async_load()
|
||||
|
@ -50,7 +52,7 @@ class AuthStore:
|
|||
|
||||
return list(self._groups.values())
|
||||
|
||||
async def async_get_group(self, group_id: str) -> Optional[models.Group]:
|
||||
async def async_get_group(self, group_id: str) -> models.Group | None:
|
||||
"""Retrieve all users."""
|
||||
if self._groups is None:
|
||||
await self._async_load()
|
||||
|
@ -58,7 +60,7 @@ class AuthStore:
|
|||
|
||||
return self._groups.get(group_id)
|
||||
|
||||
async def async_get_users(self) -> List[models.User]:
|
||||
async def async_get_users(self) -> list[models.User]:
|
||||
"""Retrieve all users."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
|
@ -66,7 +68,7 @@ class AuthStore:
|
|||
|
||||
return list(self._users.values())
|
||||
|
||||
async def async_get_user(self, user_id: str) -> Optional[models.User]:
|
||||
async def async_get_user(self, user_id: str) -> models.User | None:
|
||||
"""Retrieve a user by id."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
|
@ -76,12 +78,12 @@ class AuthStore:
|
|||
|
||||
async def async_create_user(
|
||||
self,
|
||||
name: Optional[str],
|
||||
is_owner: Optional[bool] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
system_generated: Optional[bool] = None,
|
||||
credentials: Optional[models.Credentials] = None,
|
||||
group_ids: Optional[List[str]] = None,
|
||||
name: str | None,
|
||||
is_owner: bool | None = None,
|
||||
is_active: bool | None = None,
|
||||
system_generated: bool | None = None,
|
||||
credentials: models.Credentials | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
) -> models.User:
|
||||
"""Create a new user."""
|
||||
if self._users is None:
|
||||
|
@ -97,7 +99,7 @@ class AuthStore:
|
|||
raise ValueError(f"Invalid group specified {group_id}")
|
||||
groups.append(group)
|
||||
|
||||
kwargs: Dict[str, Any] = {
|
||||
kwargs: dict[str, Any] = {
|
||||
"name": name,
|
||||
# Until we get group management, we just put everyone in the
|
||||
# same group.
|
||||
|
@ -146,9 +148,9 @@ class AuthStore:
|
|||
async def async_update_user(
|
||||
self,
|
||||
user: models.User,
|
||||
name: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
group_ids: Optional[List[str]] = None,
|
||||
name: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
group_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Update a user."""
|
||||
assert self._groups is not None
|
||||
|
@ -203,15 +205,15 @@ class AuthStore:
|
|||
async def async_create_refresh_token(
|
||||
self,
|
||||
user: models.User,
|
||||
client_id: Optional[str] = None,
|
||||
client_name: Optional[str] = None,
|
||||
client_icon: Optional[str] = None,
|
||||
client_id: str | None = None,
|
||||
client_name: str | None = None,
|
||||
client_icon: str | None = None,
|
||||
token_type: str = models.TOKEN_TYPE_NORMAL,
|
||||
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
||||
credential: Optional[models.Credentials] = None,
|
||||
credential: models.Credentials | None = None,
|
||||
) -> models.RefreshToken:
|
||||
"""Create a new token for a user."""
|
||||
kwargs: Dict[str, Any] = {
|
||||
kwargs: dict[str, Any] = {
|
||||
"user": user,
|
||||
"client_id": client_id,
|
||||
"token_type": token_type,
|
||||
|
@ -244,7 +246,7 @@ class AuthStore:
|
|||
|
||||
async def async_get_refresh_token(
|
||||
self, token_id: str
|
||||
) -> Optional[models.RefreshToken]:
|
||||
) -> models.RefreshToken | None:
|
||||
"""Get refresh token by id."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
|
@ -259,7 +261,7 @@ class AuthStore:
|
|||
|
||||
async def async_get_refresh_token_by_token(
|
||||
self, token: str
|
||||
) -> Optional[models.RefreshToken]:
|
||||
) -> models.RefreshToken | None:
|
||||
"""Get refresh token by token."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
|
@ -276,7 +278,7 @@ class AuthStore:
|
|||
|
||||
@callback
|
||||
def async_log_refresh_token_usage(
|
||||
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
|
||||
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
|
||||
) -> None:
|
||||
"""Update refresh token last used information."""
|
||||
refresh_token.last_used_at = dt_util.utcnow()
|
||||
|
@ -309,9 +311,9 @@ class AuthStore:
|
|||
self._set_defaults()
|
||||
return
|
||||
|
||||
users: Dict[str, models.User] = OrderedDict()
|
||||
groups: Dict[str, models.Group] = OrderedDict()
|
||||
credentials: Dict[str, models.Credentials] = OrderedDict()
|
||||
users: dict[str, models.User] = OrderedDict()
|
||||
groups: dict[str, models.Group] = OrderedDict()
|
||||
credentials: dict[str, models.Credentials] = OrderedDict()
|
||||
|
||||
# Soft-migrating data as we load. We are going to make sure we have a
|
||||
# read only group and an admin group. There are two states that we can
|
||||
|
@ -328,7 +330,7 @@ class AuthStore:
|
|||
# was added.
|
||||
|
||||
for group_dict in data.get("groups", []):
|
||||
policy: Optional[PolicyType] = None
|
||||
policy: PolicyType | None = None
|
||||
|
||||
if group_dict["id"] == GROUP_ID_ADMIN:
|
||||
has_admin_group = True
|
||||
|
@ -489,7 +491,7 @@ class AuthStore:
|
|||
self._store.async_delay_save(self._data_to_save, 1)
|
||||
|
||||
@callback
|
||||
def _data_to_save(self) -> Dict:
|
||||
def _data_to_save(self) -> dict:
|
||||
"""Return the data to store."""
|
||||
assert self._users is not None
|
||||
assert self._groups is not None
|
||||
|
@ -508,7 +510,7 @@ class AuthStore:
|
|||
|
||||
groups = []
|
||||
for group in self._groups.values():
|
||||
g_dict: Dict[str, Any] = {
|
||||
g_dict: dict[str, Any] = {
|
||||
"id": group.id,
|
||||
# Name not read for sys groups. Kept here for backwards compat
|
||||
"name": group.name,
|
||||
|
@ -567,7 +569,7 @@ class AuthStore:
|
|||
"""Set default values for auth store."""
|
||||
self._users = OrderedDict()
|
||||
|
||||
groups: Dict[str, models.Group] = OrderedDict()
|
||||
groups: dict[str, models.Group] = OrderedDict()
|
||||
admin_group = _system_admin_group()
|
||||
groups[admin_group.id] = admin_group
|
||||
user_group = _system_user_group()
|
||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
import importlib
|
||||
import logging
|
||||
import types
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
@ -38,7 +38,7 @@ class MultiFactorAuthModule:
|
|||
DEFAULT_TITLE = "Unnamed auth module"
|
||||
MAX_RETRY_TIME = 3
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
|
||||
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||
"""Initialize an auth module."""
|
||||
self.hass = hass
|
||||
self.config = config
|
||||
|
@ -87,7 +87,7 @@ class MultiFactorAuthModule:
|
|||
"""Return whether user is setup."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
|
||||
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
|
||||
"""Return True if validation passed."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -104,14 +104,14 @@ class SetupFlow(data_entry_flow.FlowHandler):
|
|||
self._user_id = user_id
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle the first step of setup flow.
|
||||
|
||||
Return self.async_show_form(step_id='init') if user_input is None.
|
||||
Return self.async_create_entry(data={'result': result}) if finish.
|
||||
"""
|
||||
errors: Dict[str, str] = {}
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
if user_input:
|
||||
result = await self._auth_module.async_setup_user(self._user_id, user_input)
|
||||
|
@ -125,7 +125,7 @@ class SetupFlow(data_entry_flow.FlowHandler):
|
|||
|
||||
|
||||
async def auth_mfa_module_from_config(
|
||||
hass: HomeAssistant, config: Dict[str, Any]
|
||||
hass: HomeAssistant, config: dict[str, Any]
|
||||
) -> MultiFactorAuthModule:
|
||||
"""Initialize an auth module from a config."""
|
||||
module_name = config[CONF_TYPE]
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
"""Example auth module."""
|
||||
from typing import Any, Dict
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -28,7 +30,7 @@ class InsecureExampleModule(MultiFactorAuthModule):
|
|||
|
||||
DEFAULT_TITLE = "Insecure Personal Identify Number"
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
|
||||
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||
"""Initialize the user data store."""
|
||||
super().__init__(hass, config)
|
||||
self._data = config["data"]
|
||||
|
@ -80,7 +82,7 @@ class InsecureExampleModule(MultiFactorAuthModule):
|
|||
return True
|
||||
return False
|
||||
|
||||
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
|
||||
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
|
||||
"""Return True if validation passed."""
|
||||
for data in self._data:
|
||||
if data["user_id"] == user_id:
|
||||
|
|
|
@ -2,10 +2,12 @@
|
|||
|
||||
Sending HOTP through notify service
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
import attr
|
||||
import voluptuous as vol
|
||||
|
@ -79,8 +81,8 @@ class NotifySetting:
|
|||
|
||||
secret: str = attr.ib(factory=_generate_secret) # not persistent
|
||||
counter: int = attr.ib(factory=_generate_random) # not persistent
|
||||
notify_service: Optional[str] = attr.ib(default=None)
|
||||
target: Optional[str] = attr.ib(default=None)
|
||||
notify_service: str | None = attr.ib(default=None)
|
||||
target: str | None = attr.ib(default=None)
|
||||
|
||||
|
||||
_UsersDict = Dict[str, NotifySetting]
|
||||
|
@ -92,10 +94,10 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
|||
|
||||
DEFAULT_TITLE = "Notify One-Time Password"
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
|
||||
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||
"""Initialize the user data store."""
|
||||
super().__init__(hass, config)
|
||||
self._user_settings: Optional[_UsersDict] = None
|
||||
self._user_settings: _UsersDict | None = None
|
||||
self._user_store = hass.helpers.storage.Store(
|
||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
||||
)
|
||||
|
@ -146,7 +148,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
|||
)
|
||||
|
||||
@callback
|
||||
def aync_get_available_notify_services(self) -> List[str]:
|
||||
def aync_get_available_notify_services(self) -> list[str]:
|
||||
"""Return list of notify services."""
|
||||
unordered_services = set()
|
||||
|
||||
|
@ -198,7 +200,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
|||
|
||||
return user_id in self._user_settings
|
||||
|
||||
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
|
||||
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
|
||||
"""Return True if validation passed."""
|
||||
if self._user_settings is None:
|
||||
await self._async_load()
|
||||
|
@ -258,7 +260,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
|||
)
|
||||
|
||||
async def async_notify(
|
||||
self, code: str, notify_service: str, target: Optional[str] = None
|
||||
self, code: str, notify_service: str, target: str | None = None
|
||||
) -> None:
|
||||
"""Send code by notify service."""
|
||||
data = {"message": self._message_template.format(code)}
|
||||
|
@ -276,23 +278,23 @@ class NotifySetupFlow(SetupFlow):
|
|||
auth_module: NotifyAuthModule,
|
||||
setup_schema: vol.Schema,
|
||||
user_id: str,
|
||||
available_notify_services: List[str],
|
||||
available_notify_services: list[str],
|
||||
) -> None:
|
||||
"""Initialize the setup flow."""
|
||||
super().__init__(auth_module, setup_schema, user_id)
|
||||
# to fix typing complaint
|
||||
self._auth_module: NotifyAuthModule = auth_module
|
||||
self._available_notify_services = available_notify_services
|
||||
self._secret: Optional[str] = None
|
||||
self._count: Optional[int] = None
|
||||
self._notify_service: Optional[str] = None
|
||||
self._target: Optional[str] = None
|
||||
self._secret: str | None = None
|
||||
self._count: int | None = None
|
||||
self._notify_service: str | None = None
|
||||
self._target: str | None = None
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Let user select available notify services."""
|
||||
errors: Dict[str, str] = {}
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
hass = self._auth_module.hass
|
||||
if user_input:
|
||||
|
@ -306,7 +308,7 @@ class NotifySetupFlow(SetupFlow):
|
|||
if not self._available_notify_services:
|
||||
return self.async_abort(reason="no_available_service")
|
||||
|
||||
schema: Dict[str, Any] = OrderedDict()
|
||||
schema: dict[str, Any] = OrderedDict()
|
||||
schema["notify_service"] = vol.In(self._available_notify_services)
|
||||
schema["target"] = vol.Optional(str)
|
||||
|
||||
|
@ -315,10 +317,10 @@ class NotifySetupFlow(SetupFlow):
|
|||
)
|
||||
|
||||
async def async_step_setup(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Verify user can receive one-time password."""
|
||||
errors: Dict[str, str] = {}
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
hass = self._auth_module.hass
|
||||
if user_input:
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
"""Time-based One Time Password auth module."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -50,7 +52,7 @@ def _generate_qr_code(data: str) -> str:
|
|||
)
|
||||
|
||||
|
||||
def _generate_secret_and_qr_code(username: str) -> Tuple[str, str, str]:
|
||||
def _generate_secret_and_qr_code(username: str) -> tuple[str, str, str]:
|
||||
"""Generate a secret, url, and QR code."""
|
||||
import pyotp # pylint: disable=import-outside-toplevel
|
||||
|
||||
|
@ -69,10 +71,10 @@ class TotpAuthModule(MultiFactorAuthModule):
|
|||
DEFAULT_TITLE = "Time-based One Time Password"
|
||||
MAX_RETRY_TIME = 5
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
|
||||
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||
"""Initialize the user data store."""
|
||||
super().__init__(hass, config)
|
||||
self._users: Optional[Dict[str, str]] = None
|
||||
self._users: dict[str, str] | None = None
|
||||
self._user_store = hass.helpers.storage.Store(
|
||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
||||
)
|
||||
|
@ -100,7 +102,7 @@ class TotpAuthModule(MultiFactorAuthModule):
|
|||
"""Save data."""
|
||||
await self._user_store.async_save({STORAGE_USERS: self._users})
|
||||
|
||||
def _add_ota_secret(self, user_id: str, secret: Optional[str] = None) -> str:
|
||||
def _add_ota_secret(self, user_id: str, secret: str | None = None) -> str:
|
||||
"""Create a ota_secret for user."""
|
||||
import pyotp # pylint: disable=import-outside-toplevel
|
||||
|
||||
|
@ -145,7 +147,7 @@ class TotpAuthModule(MultiFactorAuthModule):
|
|||
|
||||
return user_id in self._users # type: ignore
|
||||
|
||||
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
|
||||
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
|
||||
"""Return True if validation passed."""
|
||||
if self._users is None:
|
||||
await self._async_load()
|
||||
|
@ -181,13 +183,13 @@ class TotpSetupFlow(SetupFlow):
|
|||
# to fix typing complaint
|
||||
self._auth_module: TotpAuthModule = auth_module
|
||||
self._user = user
|
||||
self._ota_secret: Optional[str] = None
|
||||
self._ota_secret: str | None = None
|
||||
self._url = None # type Optional[str]
|
||||
self._image = None # type Optional[str]
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle the first step of setup flow.
|
||||
|
||||
Return self.async_show_form(step_id='init') if user_input is None.
|
||||
|
@ -195,7 +197,7 @@ class TotpSetupFlow(SetupFlow):
|
|||
"""
|
||||
import pyotp # pylint: disable=import-outside-toplevel
|
||||
|
||||
errors: Dict[str, str] = {}
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
if user_input:
|
||||
verified = await self.hass.async_add_executor_job(
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
"""Auth models."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
import secrets
|
||||
from typing import Dict, List, NamedTuple, Optional
|
||||
from typing import NamedTuple
|
||||
import uuid
|
||||
|
||||
import attr
|
||||
|
@ -21,7 +23,7 @@ TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
|
|||
class Group:
|
||||
"""A group."""
|
||||
|
||||
name: Optional[str] = attr.ib()
|
||||
name: str | None = attr.ib()
|
||||
policy: perm_mdl.PolicyType = attr.ib()
|
||||
id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
|
||||
system_generated: bool = attr.ib(default=False)
|
||||
|
@ -31,24 +33,24 @@ class Group:
|
|||
class User:
|
||||
"""A user."""
|
||||
|
||||
name: Optional[str] = attr.ib()
|
||||
name: str | None = attr.ib()
|
||||
perm_lookup: perm_mdl.PermissionLookup = attr.ib(eq=False, order=False)
|
||||
id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
|
||||
is_owner: bool = attr.ib(default=False)
|
||||
is_active: bool = attr.ib(default=False)
|
||||
system_generated: bool = attr.ib(default=False)
|
||||
|
||||
groups: List[Group] = attr.ib(factory=list, eq=False, order=False)
|
||||
groups: list[Group] = attr.ib(factory=list, eq=False, order=False)
|
||||
|
||||
# List of credentials of a user.
|
||||
credentials: List["Credentials"] = attr.ib(factory=list, eq=False, order=False)
|
||||
credentials: list["Credentials"] = attr.ib(factory=list, eq=False, order=False)
|
||||
|
||||
# Tokens associated with a user.
|
||||
refresh_tokens: Dict[str, "RefreshToken"] = attr.ib(
|
||||
refresh_tokens: dict[str, "RefreshToken"] = attr.ib(
|
||||
factory=dict, eq=False, order=False
|
||||
)
|
||||
|
||||
_permissions: Optional[perm_mdl.PolicyPermissions] = attr.ib(
|
||||
_permissions: perm_mdl.PolicyPermissions | None = attr.ib(
|
||||
init=False,
|
||||
eq=False,
|
||||
order=False,
|
||||
|
@ -89,10 +91,10 @@ class RefreshToken:
|
|||
"""RefreshToken for a user to grant new access tokens."""
|
||||
|
||||
user: User = attr.ib()
|
||||
client_id: Optional[str] = attr.ib()
|
||||
client_id: str | None = attr.ib()
|
||||
access_token_expiration: timedelta = attr.ib()
|
||||
client_name: Optional[str] = attr.ib(default=None)
|
||||
client_icon: Optional[str] = attr.ib(default=None)
|
||||
client_name: str | None = attr.ib(default=None)
|
||||
client_icon: str | None = attr.ib(default=None)
|
||||
token_type: str = attr.ib(
|
||||
default=TOKEN_TYPE_NORMAL,
|
||||
validator=attr.validators.in_(
|
||||
|
@ -104,12 +106,12 @@ class RefreshToken:
|
|||
token: str = attr.ib(factory=lambda: secrets.token_hex(64))
|
||||
jwt_key: str = attr.ib(factory=lambda: secrets.token_hex(64))
|
||||
|
||||
last_used_at: Optional[datetime] = attr.ib(default=None)
|
||||
last_used_ip: Optional[str] = attr.ib(default=None)
|
||||
last_used_at: datetime | None = attr.ib(default=None)
|
||||
last_used_ip: str | None = attr.ib(default=None)
|
||||
|
||||
credential: Optional["Credentials"] = attr.ib(default=None)
|
||||
credential: "Credentials" | None = attr.ib(default=None)
|
||||
|
||||
version: Optional[str] = attr.ib(default=__version__)
|
||||
version: str | None = attr.ib(default=__version__)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
|
@ -117,7 +119,7 @@ class Credentials:
|
|||
"""Credentials for a user on an auth provider."""
|
||||
|
||||
auth_provider_type: str = attr.ib()
|
||||
auth_provider_id: Optional[str] = attr.ib()
|
||||
auth_provider_id: str | None = attr.ib()
|
||||
|
||||
# Allow the auth provider to store data to represent their auth.
|
||||
data: dict = attr.ib()
|
||||
|
@ -129,5 +131,5 @@ class Credentials:
|
|||
class UserMeta(NamedTuple):
|
||||
"""User metadata."""
|
||||
|
||||
name: Optional[str]
|
||||
name: str | None
|
||||
is_active: bool
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""Permissions for Home Assistant."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -19,7 +21,7 @@ _LOGGER = logging.getLogger(__name__)
|
|||
class AbstractPermissions:
|
||||
"""Default permissions class."""
|
||||
|
||||
_cached_entity_func: Optional[Callable[[str, str], bool]] = None
|
||||
_cached_entity_func: Callable[[str, str], bool] | None = None
|
||||
|
||||
def _entity_func(self) -> Callable[[str, str], bool]:
|
||||
"""Return a function that can test entity access."""
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""Entity permissions."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -43,14 +45,14 @@ ENTITY_POLICY_SCHEMA = vol.Any(
|
|||
|
||||
def _lookup_domain(
|
||||
perm_lookup: PermissionLookup, domains_dict: SubCategoryDict, entity_id: str
|
||||
) -> Optional[ValueType]:
|
||||
) -> ValueType | None:
|
||||
"""Look up entity permissions by domain."""
|
||||
return domains_dict.get(entity_id.split(".", 1)[0])
|
||||
|
||||
|
||||
def _lookup_area(
|
||||
perm_lookup: PermissionLookup, area_dict: SubCategoryDict, entity_id: str
|
||||
) -> Optional[ValueType]:
|
||||
) -> ValueType | None:
|
||||
"""Look up entity permissions by area."""
|
||||
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
|
||||
|
||||
|
@ -67,7 +69,7 @@ def _lookup_area(
|
|||
|
||||
def _lookup_device(
|
||||
perm_lookup: PermissionLookup, devices_dict: SubCategoryDict, entity_id: str
|
||||
) -> Optional[ValueType]:
|
||||
) -> ValueType | None:
|
||||
"""Look up entity permissions by device."""
|
||||
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
|
||||
|
||||
|
@ -79,7 +81,7 @@ def _lookup_device(
|
|||
|
||||
def _lookup_entity_id(
|
||||
perm_lookup: PermissionLookup, entities_dict: SubCategoryDict, entity_id: str
|
||||
) -> Optional[ValueType]:
|
||||
) -> ValueType | None:
|
||||
"""Look up entity permission by entity id."""
|
||||
return entities_dict.get(entity_id)
|
||||
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
"""Merging of policies."""
|
||||
from typing import Dict, List, Set, cast
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
from .types import CategoryType, PolicyType
|
||||
|
||||
|
||||
def merge_policies(policies: List[PolicyType]) -> PolicyType:
|
||||
def merge_policies(policies: list[PolicyType]) -> PolicyType:
|
||||
"""Merge policies."""
|
||||
new_policy: Dict[str, CategoryType] = {}
|
||||
seen: Set[str] = set()
|
||||
new_policy: dict[str, CategoryType] = {}
|
||||
seen: set[str] = set()
|
||||
for policy in policies:
|
||||
for category in policy:
|
||||
if category in seen:
|
||||
|
@ -20,7 +22,7 @@ def merge_policies(policies: List[PolicyType]) -> PolicyType:
|
|||
return new_policy
|
||||
|
||||
|
||||
def _merge_policies(sources: List[CategoryType]) -> CategoryType:
|
||||
def _merge_policies(sources: list[CategoryType]) -> CategoryType:
|
||||
"""Merge a policy."""
|
||||
# When merging policies, the most permissive wins.
|
||||
# This means we order it like this:
|
||||
|
@ -34,7 +36,7 @@ def _merge_policies(sources: List[CategoryType]) -> CategoryType:
|
|||
# merge each key in the source.
|
||||
|
||||
policy: CategoryType = None
|
||||
seen: Set[str] = set()
|
||||
seen: set[str] = set()
|
||||
for source in sources:
|
||||
if source is None:
|
||||
continue
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""Helpers to deal with permissions."""
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict, List, Optional, cast
|
||||
from typing import Callable, Dict, Optional, cast
|
||||
|
||||
from .const import SUBCAT_ALL
|
||||
from .models import PermissionLookup
|
||||
|
@ -45,7 +47,7 @@ def compile_policy(
|
|||
|
||||
assert isinstance(policy, dict)
|
||||
|
||||
funcs: List[Callable[[str, str], Optional[bool]]] = []
|
||||
funcs: list[Callable[[str, str], bool | None]] = []
|
||||
|
||||
for key, lookup_func in subcategories.items():
|
||||
lookup_value = policy.get(key)
|
||||
|
@ -80,10 +82,10 @@ def compile_policy(
|
|||
|
||||
def _gen_dict_test_func(
|
||||
perm_lookup: PermissionLookup, lookup_func: LookupFunc, lookup_dict: SubCategoryDict
|
||||
) -> Callable[[str, str], Optional[bool]]:
|
||||
) -> Callable[[str, str], bool | None]:
|
||||
"""Generate a lookup function."""
|
||||
|
||||
def test_value(object_id: str, key: str) -> Optional[bool]:
|
||||
def test_value(object_id: str, key: str) -> bool | None:
|
||||
"""Test if permission is allowed based on the keys."""
|
||||
schema: ValueType = lookup_func(perm_lookup, lookup_dict, object_id)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
import importlib
|
||||
import logging
|
||||
import types
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
@ -42,7 +42,7 @@ class AuthProvider:
|
|||
DEFAULT_TITLE = "Unnamed auth provider"
|
||||
|
||||
def __init__(
|
||||
self, hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
|
||||
self, hass: HomeAssistant, store: AuthStore, config: dict[str, Any]
|
||||
) -> None:
|
||||
"""Initialize an auth provider."""
|
||||
self.hass = hass
|
||||
|
@ -50,7 +50,7 @@ class AuthProvider:
|
|||
self.config = config
|
||||
|
||||
@property
|
||||
def id(self) -> Optional[str]:
|
||||
def id(self) -> str | None:
|
||||
"""Return id of the auth provider.
|
||||
|
||||
Optional, can be None.
|
||||
|
@ -72,7 +72,7 @@ class AuthProvider:
|
|||
"""Return whether multi-factor auth supported by the auth provider."""
|
||||
return True
|
||||
|
||||
async def async_credentials(self) -> List[Credentials]:
|
||||
async def async_credentials(self) -> list[Credentials]:
|
||||
"""Return all credentials of this provider."""
|
||||
users = await self.store.async_get_users()
|
||||
return [
|
||||
|
@ -86,7 +86,7 @@ class AuthProvider:
|
|||
]
|
||||
|
||||
@callback
|
||||
def async_create_credentials(self, data: Dict[str, str]) -> Credentials:
|
||||
def async_create_credentials(self, data: dict[str, str]) -> Credentials:
|
||||
"""Create credentials."""
|
||||
return Credentials(
|
||||
auth_provider_type=self.type, auth_provider_id=self.id, data=data
|
||||
|
@ -94,7 +94,7 @@ class AuthProvider:
|
|||
|
||||
# Implement by extending class
|
||||
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||
"""Return the data flow for logging in with auth provider.
|
||||
|
||||
Auth provider should extend LoginFlow and return an instance.
|
||||
|
@ -102,7 +102,7 @@ class AuthProvider:
|
|||
raise NotImplementedError
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]
|
||||
self, flow_result: dict[str, str]
|
||||
) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
raise NotImplementedError
|
||||
|
@ -121,7 +121,7 @@ class AuthProvider:
|
|||
|
||||
@callback
|
||||
def async_validate_refresh_token(
|
||||
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
|
||||
self, refresh_token: RefreshToken, remote_ip: str | None = None
|
||||
) -> None:
|
||||
"""Verify a refresh token is still valid.
|
||||
|
||||
|
@ -131,7 +131,7 @@ class AuthProvider:
|
|||
|
||||
|
||||
async def auth_provider_from_config(
|
||||
hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
|
||||
hass: HomeAssistant, store: AuthStore, config: dict[str, Any]
|
||||
) -> AuthProvider:
|
||||
"""Initialize an auth provider from a config."""
|
||||
provider_name = config[CONF_TYPE]
|
||||
|
@ -188,17 +188,17 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
def __init__(self, auth_provider: AuthProvider) -> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
self._auth_module_id: Optional[str] = None
|
||||
self._auth_module_id: str | None = None
|
||||
self._auth_manager = auth_provider.hass.auth
|
||||
self.available_mfa_modules: Dict[str, str] = {}
|
||||
self.available_mfa_modules: dict[str, str] = {}
|
||||
self.created_at = dt_util.utcnow()
|
||||
self.invalid_mfa_times = 0
|
||||
self.user: Optional[User] = None
|
||||
self.credential: Optional[Credentials] = None
|
||||
self.user: User | None = None
|
||||
self.credential: Credentials | None = None
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle the first step of login flow.
|
||||
|
||||
Return self.async_show_form(step_id='init') if user_input is None.
|
||||
|
@ -207,8 +207,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
raise NotImplementedError
|
||||
|
||||
async def async_step_select_mfa_module(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle the step of select mfa module."""
|
||||
errors = {}
|
||||
|
||||
|
@ -232,8 +232,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
)
|
||||
|
||||
async def async_step_mfa(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle the step of mfa validation."""
|
||||
assert self.credential
|
||||
assert self.user
|
||||
|
@ -273,7 +273,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
if not errors:
|
||||
return await self.async_finish(self.credential)
|
||||
|
||||
description_placeholders: Dict[str, Optional[str]] = {
|
||||
description_placeholders: dict[str, str | None] = {
|
||||
"mfa_module_name": auth_module.name,
|
||||
"mfa_module_id": auth_module.id,
|
||||
}
|
||||
|
@ -285,6 +285,6 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_finish(self, flow_result: Any) -> Dict:
|
||||
async def async_finish(self, flow_result: Any) -> dict:
|
||||
"""Handle the pass of login flow."""
|
||||
return self.async_create_entry(title=self._auth_provider.name, data=flow_result)
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
"""Auth provider that validates credentials via an external command."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio.subprocess
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -51,9 +52,9 @@ class CommandLineAuthProvider(AuthProvider):
|
|||
attributes provided by external programs.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._user_meta: Dict[str, Dict[str, Any]] = {}
|
||||
self._user_meta: dict[str, dict[str, Any]] = {}
|
||||
|
||||
async def async_login_flow(self, context: Optional[dict]) -> LoginFlow:
|
||||
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return CommandLineLoginFlow(self)
|
||||
|
||||
|
@ -82,7 +83,7 @@ class CommandLineAuthProvider(AuthProvider):
|
|||
raise InvalidAuthError
|
||||
|
||||
if self.config[CONF_META]:
|
||||
meta: Dict[str, str] = {}
|
||||
meta: dict[str, str] = {}
|
||||
for _line in stdout.splitlines():
|
||||
try:
|
||||
line = _line.decode().lstrip()
|
||||
|
@ -99,7 +100,7 @@ class CommandLineAuthProvider(AuthProvider):
|
|||
self._user_meta[username] = meta
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]
|
||||
self, flow_result: dict[str, str]
|
||||
) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
username = flow_result["username"]
|
||||
|
@ -125,8 +126,8 @@ class CommandLineLoginFlow(LoginFlow):
|
|||
"""Handler for the login flow."""
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
|
||||
|
@ -143,7 +144,7 @@ class CommandLineLoginFlow(LoginFlow):
|
|||
user_input.pop("password")
|
||||
return await self.async_finish(user_input)
|
||||
|
||||
schema: Dict[str, type] = collections.OrderedDict()
|
||||
schema: dict[str, type] = collections.OrderedDict()
|
||||
schema["username"] = str
|
||||
schema["password"] = str
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import asyncio
|
|||
import base64
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Set, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import bcrypt
|
||||
import voluptuous as vol
|
||||
|
@ -21,7 +21,7 @@ STORAGE_VERSION = 1
|
|||
STORAGE_KEY = "auth_provider.homeassistant"
|
||||
|
||||
|
||||
def _disallow_id(conf: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _disallow_id(conf: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Disallow ID in config."""
|
||||
if CONF_ID in conf:
|
||||
raise vol.Invalid("ID is not allowed for the homeassistant auth provider.")
|
||||
|
@ -62,7 +62,7 @@ class Data:
|
|||
self._store = hass.helpers.storage.Store(
|
||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
||||
)
|
||||
self._data: Optional[Dict[str, Any]] = None
|
||||
self._data: dict[str, Any] | None = None
|
||||
# Legacy mode will allow usernames to start/end with whitespace
|
||||
# and will compare usernames case-insensitive.
|
||||
# Remove in 2020 or when we launch 1.0.
|
||||
|
@ -83,7 +83,7 @@ class Data:
|
|||
if data is None:
|
||||
data = {"users": []}
|
||||
|
||||
seen: Set[str] = set()
|
||||
seen: set[str] = set()
|
||||
|
||||
for user in data["users"]:
|
||||
username = user["username"]
|
||||
|
@ -121,7 +121,7 @@ class Data:
|
|||
self._data = data
|
||||
|
||||
@property
|
||||
def users(self) -> List[Dict[str, str]]:
|
||||
def users(self) -> list[dict[str, str]]:
|
||||
"""Return users."""
|
||||
return self._data["users"] # type: ignore
|
||||
|
||||
|
@ -220,7 +220,7 @@ class HassAuthProvider(AuthProvider):
|
|||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Initialize an Home Assistant auth provider."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.data: Optional[Data] = None
|
||||
self.data: Data | None = None
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
async def async_initialize(self) -> None:
|
||||
|
@ -233,7 +233,7 @@ class HassAuthProvider(AuthProvider):
|
|||
await data.async_load()
|
||||
self.data = data
|
||||
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return HassLoginFlow(self)
|
||||
|
||||
|
@ -277,7 +277,7 @@ class HassAuthProvider(AuthProvider):
|
|||
await self.data.async_save()
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]
|
||||
self, flow_result: dict[str, str]
|
||||
) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
if self.data is None:
|
||||
|
@ -318,8 +318,8 @@ class HassLoginFlow(LoginFlow):
|
|||
"""Handler for the login flow."""
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
|
||||
|
@ -335,7 +335,7 @@ class HassLoginFlow(LoginFlow):
|
|||
user_input.pop("password")
|
||||
return await self.async_finish(user_input)
|
||||
|
||||
schema: Dict[str, type] = OrderedDict()
|
||||
schema: dict[str, type] = OrderedDict()
|
||||
schema["username"] = str
|
||||
schema["password"] = str
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
"""Example auth provider."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
import hmac
|
||||
from typing import Any, Dict, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -33,7 +35,7 @@ class InvalidAuthError(HomeAssistantError):
|
|||
class ExampleAuthProvider(AuthProvider):
|
||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return ExampleLoginFlow(self)
|
||||
|
||||
|
@ -60,7 +62,7 @@ class ExampleAuthProvider(AuthProvider):
|
|||
raise InvalidAuthError
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]
|
||||
self, flow_result: dict[str, str]
|
||||
) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
username = flow_result["username"]
|
||||
|
@ -94,8 +96,8 @@ class ExampleLoginFlow(LoginFlow):
|
|||
"""Handler for the login flow."""
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
|
||||
|
@ -111,7 +113,7 @@ class ExampleLoginFlow(LoginFlow):
|
|||
user_input.pop("password")
|
||||
return await self.async_finish(user_input)
|
||||
|
||||
schema: Dict[str, type] = OrderedDict()
|
||||
schema: dict[str, type] = OrderedDict()
|
||||
schema["username"] = str
|
||||
schema["password"] = str
|
||||
|
||||
|
|
|
@ -3,8 +3,10 @@ Support Legacy API password auth provider.
|
|||
|
||||
It will be removed when auth system production ready
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
from typing import Any, Dict, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -40,7 +42,7 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
|||
"""Return api_password."""
|
||||
return str(self.config[CONF_API_PASSWORD])
|
||||
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return LegacyLoginFlow(self)
|
||||
|
||||
|
@ -55,7 +57,7 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
|||
raise InvalidAuthError
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]
|
||||
self, flow_result: dict[str, str]
|
||||
) -> Credentials:
|
||||
"""Return credentials for this login."""
|
||||
credentials = await self.async_credentials()
|
||||
|
@ -79,8 +81,8 @@ class LegacyLoginFlow(LoginFlow):
|
|||
"""Handler for the login flow."""
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle the step of the form."""
|
||||
errors = {}
|
||||
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
It shows list of users if access from trusted network.
|
||||
Abort login flow if not access from trusted network.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import (
|
||||
IPv4Address,
|
||||
IPv4Network,
|
||||
|
@ -11,7 +13,7 @@ from ipaddress import (
|
|||
ip_address,
|
||||
ip_network,
|
||||
)
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
from typing import Any, Dict, List, Union, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -68,12 +70,12 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
DEFAULT_TITLE = "Trusted Networks"
|
||||
|
||||
@property
|
||||
def trusted_networks(self) -> List[IPNetwork]:
|
||||
def trusted_networks(self) -> list[IPNetwork]:
|
||||
"""Return trusted networks."""
|
||||
return cast(List[IPNetwork], self.config[CONF_TRUSTED_NETWORKS])
|
||||
|
||||
@property
|
||||
def trusted_users(self) -> Dict[IPNetwork, Any]:
|
||||
def trusted_users(self) -> dict[IPNetwork, Any]:
|
||||
"""Return trusted users per network."""
|
||||
return cast(Dict[IPNetwork, Any], self.config[CONF_TRUSTED_USERS])
|
||||
|
||||
|
@ -82,7 +84,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
"""Trusted Networks auth provider does not support MFA."""
|
||||
return False
|
||||
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
assert context is not None
|
||||
ip_addr = cast(IPAddress, context.get("ip_address"))
|
||||
|
@ -125,7 +127,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
)
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Dict[str, str]
|
||||
self, flow_result: dict[str, str]
|
||||
) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
user_id = flow_result["user"]
|
||||
|
@ -169,7 +171,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||
|
||||
@callback
|
||||
def async_validate_refresh_token(
|
||||
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
|
||||
self, refresh_token: RefreshToken, remote_ip: str | None = None
|
||||
) -> None:
|
||||
"""Verify a refresh token is still valid."""
|
||||
if remote_ip is None:
|
||||
|
@ -186,7 +188,7 @@ class TrustedNetworksLoginFlow(LoginFlow):
|
|||
self,
|
||||
auth_provider: TrustedNetworksAuthProvider,
|
||||
ip_addr: IPAddress,
|
||||
available_users: Dict[str, Optional[str]],
|
||||
available_users: dict[str, str | None],
|
||||
allow_bypass_login: bool,
|
||||
) -> None:
|
||||
"""Initialize the login flow."""
|
||||
|
@ -196,8 +198,8 @@ class TrustedNetworksLoginFlow(LoginFlow):
|
|||
self._allow_bypass_login = allow_bypass_login
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle the step of the form."""
|
||||
try:
|
||||
cast(
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
"""Home Assistant command line scripts."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional, Sequence, Text
|
||||
from typing import Sequence
|
||||
|
||||
from homeassistant import runner
|
||||
from homeassistant.bootstrap import async_mount_local_lib_path
|
||||
|
@ -16,7 +18,7 @@ from homeassistant.util.package import install_package, is_installed, is_virtual
|
|||
# mypy: allow-untyped-defs, no-warn-return-any
|
||||
|
||||
|
||||
def run(args: List) -> int:
|
||||
def run(args: list) -> int:
|
||||
"""Run a script."""
|
||||
scripts = []
|
||||
path = os.path.dirname(__file__)
|
||||
|
@ -65,7 +67,7 @@ def run(args: List) -> int:
|
|||
return script.run(args[1:]) # type: ignore
|
||||
|
||||
|
||||
def extract_config_dir(args: Optional[Sequence[Text]] = None) -> str:
|
||||
def extract_config_dir(args: Sequence[str] | None = None) -> str:
|
||||
"""Extract the config dir from the arguments or get the default."""
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("-c", "--config", default=None)
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""Script to run benchmarks."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import collections
|
||||
|
@ -7,7 +9,7 @@ from datetime import datetime
|
|||
import json
|
||||
import logging
|
||||
from timeit import default_timer as timer
|
||||
from typing import Callable, Dict, TypeVar
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from homeassistant import core
|
||||
from homeassistant.components.websocket_api.const import JSON_DUMP
|
||||
|
@ -21,7 +23,7 @@ from homeassistant.util import dt as dt_util
|
|||
|
||||
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
BENCHMARKS: Dict[str, Callable] = {}
|
||||
BENCHMARKS: dict[str, Callable] = {}
|
||||
|
||||
|
||||
def run(args):
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""Script to check the configuration file."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
|
@ -6,7 +8,7 @@ from collections.abc import Mapping, Sequence
|
|||
from glob import glob
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
from typing import Any, Callable
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant import core
|
||||
|
@ -22,13 +24,13 @@ REQUIREMENTS = ("colorlog==4.7.2",)
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
# pylint: disable=protected-access
|
||||
MOCKS: Dict[str, Tuple[str, Callable]] = {
|
||||
MOCKS: dict[str, tuple[str, Callable]] = {
|
||||
"load": ("homeassistant.util.yaml.loader.load_yaml", yaml_loader.load_yaml),
|
||||
"load*": ("homeassistant.config.load_yaml", yaml_loader.load_yaml),
|
||||
"secrets": ("homeassistant.util.yaml.loader.secret_yaml", yaml_loader.secret_yaml),
|
||||
}
|
||||
|
||||
PATCHES: Dict[str, Any] = {}
|
||||
PATCHES: dict[str, Any] = {}
|
||||
|
||||
C_HEAD = "bold"
|
||||
ERROR_STR = "General Errors"
|
||||
|
@ -48,7 +50,7 @@ def color(the_color, *args, reset=None):
|
|||
raise ValueError(f"Invalid color {k!s} in {the_color}") from k
|
||||
|
||||
|
||||
def run(script_args: List) -> int:
|
||||
def run(script_args: list) -> int:
|
||||
"""Handle check config commandline script."""
|
||||
parser = argparse.ArgumentParser(description="Check Home Assistant configuration.")
|
||||
parser.add_argument("--script", choices=["check_config"])
|
||||
|
@ -83,7 +85,7 @@ def run(script_args: List) -> int:
|
|||
|
||||
res = check(config_dir, args.secrets)
|
||||
|
||||
domain_info: List[str] = []
|
||||
domain_info: list[str] = []
|
||||
if args.info:
|
||||
domain_info = args.info.split(",")
|
||||
|
||||
|
@ -123,7 +125,7 @@ def run(script_args: List) -> int:
|
|||
dump_dict(res["components"].get(domain))
|
||||
|
||||
if args.secrets:
|
||||
flatsecret: Dict[str, str] = {}
|
||||
flatsecret: dict[str, str] = {}
|
||||
|
||||
for sfn, sdict in res["secret_cache"].items():
|
||||
sss = []
|
||||
|
@ -149,7 +151,7 @@ def run(script_args: List) -> int:
|
|||
def check(config_dir, secrets=False):
|
||||
"""Perform a check by mocking hass load functions."""
|
||||
logging.getLogger("homeassistant.loader").setLevel(logging.CRITICAL)
|
||||
res: Dict[str, Any] = {
|
||||
res: dict[str, Any] = {
|
||||
"yaml_files": OrderedDict(), # yaml_files loaded
|
||||
"secrets": OrderedDict(), # secret cache and secrets loaded
|
||||
"except": OrderedDict(), # exceptions raised (with config)
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""Helper methods for various modules."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
import enum
|
||||
|
@ -9,16 +11,7 @@ import socket
|
|||
import string
|
||||
import threading
|
||||
from types import MappingProxyType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Iterable,
|
||||
KeysView,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Coroutine, Iterable, KeysView, TypeVar
|
||||
|
||||
import slugify as unicode_slug
|
||||
|
||||
|
@ -106,8 +99,8 @@ def repr_helper(inp: Any) -> str:
|
|||
|
||||
|
||||
def convert(
|
||||
value: Optional[T], to_type: Callable[[T], U], default: Optional[U] = None
|
||||
) -> Optional[U]:
|
||||
value: T | None, to_type: Callable[[T], U], default: U | None = None
|
||||
) -> U | None:
|
||||
"""Convert value to to_type, returns default if fails."""
|
||||
try:
|
||||
return default if value is None else to_type(value)
|
||||
|
@ -117,7 +110,7 @@ def convert(
|
|||
|
||||
|
||||
def ensure_unique_string(
|
||||
preferred_string: str, current_strings: Union[Iterable[str], KeysView[str]]
|
||||
preferred_string: str, current_strings: Iterable[str] | KeysView[str]
|
||||
) -> str:
|
||||
"""Return a string that is not present in current_strings.
|
||||
|
||||
|
@ -213,7 +206,7 @@ class Throttle:
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, min_time: timedelta, limit_no_throttle: Optional[timedelta] = None
|
||||
self, min_time: timedelta, limit_no_throttle: timedelta | None = None
|
||||
) -> None:
|
||||
"""Initialize the throttle."""
|
||||
self.min_time = min_time
|
||||
|
@ -253,7 +246,7 @@ class Throttle:
|
|||
)
|
||||
|
||||
@wraps(method)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]:
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Callable | Coroutine:
|
||||
"""Wrap that allows wrapped to be called only once per min_time.
|
||||
|
||||
If we cannot acquire the lock, it is running so return None.
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
"""Utilities to help with aiohttp."""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qsl
|
||||
|
||||
from multidict import CIMultiDict, MultiDict
|
||||
|
@ -26,7 +28,7 @@ class MockStreamReader:
|
|||
class MockRequest:
|
||||
"""Mock an aiohttp request."""
|
||||
|
||||
mock_source: Optional[str] = None
|
||||
mock_source: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -34,8 +36,8 @@ class MockRequest:
|
|||
mock_source: str,
|
||||
method: str = "GET",
|
||||
status: int = HTTP_OK,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
query_string: Optional[str] = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
query_string: str | None = None,
|
||||
url: str = "",
|
||||
) -> None:
|
||||
"""Initialize a request."""
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
"""Color util methods."""
|
||||
from __future__ import annotations
|
||||
|
||||
import colorsys
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -183,7 +184,7 @@ class GamutType:
|
|||
blue: XYPoint = attr.ib()
|
||||
|
||||
|
||||
def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]:
|
||||
def color_name_to_rgb(color_name: str) -> tuple[int, int, int]:
|
||||
"""Convert color name to RGB hex value."""
|
||||
# COLORS map has no spaces in it, so make the color_name have no
|
||||
# spaces in it as well for matching purposes
|
||||
|
@ -198,8 +199,8 @@ def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]:
|
|||
|
||||
|
||||
def color_RGB_to_xy(
|
||||
iR: int, iG: int, iB: int, Gamut: Optional[GamutType] = None
|
||||
) -> Tuple[float, float]:
|
||||
iR: int, iG: int, iB: int, Gamut: GamutType | None = None
|
||||
) -> tuple[float, float]:
|
||||
"""Convert from RGB color to XY color."""
|
||||
return color_RGB_to_xy_brightness(iR, iG, iB, Gamut)[:2]
|
||||
|
||||
|
@ -208,8 +209,8 @@ def color_RGB_to_xy(
|
|||
# http://www.developers.meethue.com/documentation/color-conversions-rgb-xy
|
||||
# License: Code is given as is. Use at your own risk and discretion.
|
||||
def color_RGB_to_xy_brightness(
|
||||
iR: int, iG: int, iB: int, Gamut: Optional[GamutType] = None
|
||||
) -> Tuple[float, float, int]:
|
||||
iR: int, iG: int, iB: int, Gamut: GamutType | None = None
|
||||
) -> tuple[float, float, int]:
|
||||
"""Convert from RGB color to XY color."""
|
||||
if iR + iG + iB == 0:
|
||||
return 0.0, 0.0, 0
|
||||
|
@ -248,8 +249,8 @@ def color_RGB_to_xy_brightness(
|
|||
|
||||
|
||||
def color_xy_to_RGB(
|
||||
vX: float, vY: float, Gamut: Optional[GamutType] = None
|
||||
) -> Tuple[int, int, int]:
|
||||
vX: float, vY: float, Gamut: GamutType | None = None
|
||||
) -> tuple[int, int, int]:
|
||||
"""Convert from XY to a normalized RGB."""
|
||||
return color_xy_brightness_to_RGB(vX, vY, 255, Gamut)
|
||||
|
||||
|
@ -257,8 +258,8 @@ def color_xy_to_RGB(
|
|||
# Converted to Python from Obj-C, original source from:
|
||||
# http://www.developers.meethue.com/documentation/color-conversions-rgb-xy
|
||||
def color_xy_brightness_to_RGB(
|
||||
vX: float, vY: float, ibrightness: int, Gamut: Optional[GamutType] = None
|
||||
) -> Tuple[int, int, int]:
|
||||
vX: float, vY: float, ibrightness: int, Gamut: GamutType | None = None
|
||||
) -> tuple[int, int, int]:
|
||||
"""Convert from XYZ to RGB."""
|
||||
if Gamut:
|
||||
if not check_point_in_lamps_reach((vX, vY), Gamut):
|
||||
|
@ -304,7 +305,7 @@ def color_xy_brightness_to_RGB(
|
|||
return (ir, ig, ib)
|
||||
|
||||
|
||||
def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]:
|
||||
def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> tuple[int, int, int]:
|
||||
"""Convert a hsb into its rgb representation."""
|
||||
if fS == 0.0:
|
||||
fV = int(fB * 255)
|
||||
|
@ -345,7 +346,7 @@ def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]:
|
|||
return (r, g, b)
|
||||
|
||||
|
||||
def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> Tuple[float, float, float]:
|
||||
def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> tuple[float, float, float]:
|
||||
"""Convert an rgb color to its hsv representation.
|
||||
|
||||
Hue is scaled 0-360
|
||||
|
@ -356,12 +357,12 @@ def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> Tuple[float, float, flo
|
|||
return round(fHSV[0] * 360, 3), round(fHSV[1] * 100, 3), round(fHSV[2] * 100, 3)
|
||||
|
||||
|
||||
def color_RGB_to_hs(iR: float, iG: float, iB: float) -> Tuple[float, float]:
|
||||
def color_RGB_to_hs(iR: float, iG: float, iB: float) -> tuple[float, float]:
|
||||
"""Convert an rgb color to its hs representation."""
|
||||
return color_RGB_to_hsv(iR, iG, iB)[:2]
|
||||
|
||||
|
||||
def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> Tuple[int, int, int]:
|
||||
def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> tuple[int, int, int]:
|
||||
"""Convert an hsv color into its rgb representation.
|
||||
|
||||
Hue is scaled 0-360
|
||||
|
@ -372,27 +373,27 @@ def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> Tuple[int, int, int]:
|
|||
return (int(fRGB[0] * 255), int(fRGB[1] * 255), int(fRGB[2] * 255))
|
||||
|
||||
|
||||
def color_hs_to_RGB(iH: float, iS: float) -> Tuple[int, int, int]:
|
||||
def color_hs_to_RGB(iH: float, iS: float) -> tuple[int, int, int]:
|
||||
"""Convert an hsv color into its rgb representation."""
|
||||
return color_hsv_to_RGB(iH, iS, 100)
|
||||
|
||||
|
||||
def color_xy_to_hs(
|
||||
vX: float, vY: float, Gamut: Optional[GamutType] = None
|
||||
) -> Tuple[float, float]:
|
||||
vX: float, vY: float, Gamut: GamutType | None = None
|
||||
) -> tuple[float, float]:
|
||||
"""Convert an xy color to its hs representation."""
|
||||
h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY, Gamut))
|
||||
return h, s
|
||||
|
||||
|
||||
def color_hs_to_xy(
|
||||
iH: float, iS: float, Gamut: Optional[GamutType] = None
|
||||
) -> Tuple[float, float]:
|
||||
iH: float, iS: float, Gamut: GamutType | None = None
|
||||
) -> tuple[float, float]:
|
||||
"""Convert an hs color to its xy representation."""
|
||||
return color_RGB_to_xy(*color_hs_to_RGB(iH, iS), Gamut)
|
||||
|
||||
|
||||
def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple:
|
||||
def _match_max_scale(input_colors: tuple, output_colors: tuple) -> tuple:
|
||||
"""Match the maximum value of the output to the input."""
|
||||
max_in = max(input_colors)
|
||||
max_out = max(output_colors)
|
||||
|
@ -403,7 +404,7 @@ def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple:
|
|||
return tuple(int(round(i * factor)) for i in output_colors)
|
||||
|
||||
|
||||
def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]:
|
||||
def color_rgb_to_rgbw(r: int, g: int, b: int) -> tuple[int, int, int, int]:
|
||||
"""Convert an rgb color to an rgbw representation."""
|
||||
# Calculate the white channel as the minimum of input rgb channels.
|
||||
# Subtract the white portion from the remaining rgb channels.
|
||||
|
@ -415,7 +416,7 @@ def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]:
|
|||
return _match_max_scale((r, g, b), rgbw) # type: ignore
|
||||
|
||||
|
||||
def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> Tuple[int, int, int]:
|
||||
def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> tuple[int, int, int]:
|
||||
"""Convert an rgbw color to an rgb representation."""
|
||||
# Add the white channel back into the rgb channels.
|
||||
rgb = (r + w, g + w, b + w)
|
||||
|
@ -430,7 +431,7 @@ def color_rgb_to_hex(r: int, g: int, b: int) -> str:
|
|||
return "{:02x}{:02x}{:02x}".format(round(r), round(g), round(b))
|
||||
|
||||
|
||||
def rgb_hex_to_rgb_list(hex_string: str) -> List[int]:
|
||||
def rgb_hex_to_rgb_list(hex_string: str) -> list[int]:
|
||||
"""Return an RGB color value list from a hex color string."""
|
||||
return [
|
||||
int(hex_string[i : i + len(hex_string) // 3], 16)
|
||||
|
@ -438,14 +439,14 @@ def rgb_hex_to_rgb_list(hex_string: str) -> List[int]:
|
|||
]
|
||||
|
||||
|
||||
def color_temperature_to_hs(color_temperature_kelvin: float) -> Tuple[float, float]:
|
||||
def color_temperature_to_hs(color_temperature_kelvin: float) -> tuple[float, float]:
|
||||
"""Return an hs color from a color temperature in Kelvin."""
|
||||
return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin))
|
||||
|
||||
|
||||
def color_temperature_to_rgb(
|
||||
color_temperature_kelvin: float,
|
||||
) -> Tuple[float, float, float]:
|
||||
) -> tuple[float, float, float]:
|
||||
"""
|
||||
Return an RGB color from a color temperature in Kelvin.
|
||||
|
||||
|
@ -555,8 +556,8 @@ def get_closest_point_to_line(A: XYPoint, B: XYPoint, P: XYPoint) -> XYPoint:
|
|||
|
||||
|
||||
def get_closest_point_to_point(
|
||||
xy_tuple: Tuple[float, float], Gamut: GamutType
|
||||
) -> Tuple[float, float]:
|
||||
xy_tuple: tuple[float, float], Gamut: GamutType
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Get the closest matching color within the gamut of the light.
|
||||
|
||||
|
@ -592,7 +593,7 @@ def get_closest_point_to_point(
|
|||
return (cx, cy)
|
||||
|
||||
|
||||
def check_point_in_lamps_reach(p: Tuple[float, float], Gamut: GamutType) -> bool:
|
||||
def check_point_in_lamps_reach(p: tuple[float, float], Gamut: GamutType) -> bool:
|
||||
"""Check if the provided XYPoint can be recreated by a Hue lamp."""
|
||||
v1 = XYPoint(Gamut.green.x - Gamut.red.x, Gamut.green.y - Gamut.red.y)
|
||||
v2 = XYPoint(Gamut.blue.x - Gamut.red.x, Gamut.blue.y - Gamut.red.y)
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""Distance util functions."""
|
||||
from __future__ import annotations
|
||||
|
||||
from numbers import Number
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable
|
||||
|
||||
from homeassistant.const import (
|
||||
LENGTH,
|
||||
|
@ -26,7 +28,7 @@ VALID_UNITS = [
|
|||
LENGTH_YARD,
|
||||
]
|
||||
|
||||
TO_METERS: Dict[str, Callable[[float], float]] = {
|
||||
TO_METERS: dict[str, Callable[[float], float]] = {
|
||||
LENGTH_METERS: lambda meters: meters,
|
||||
LENGTH_MILES: lambda miles: miles * 1609.344,
|
||||
LENGTH_YARD: lambda yards: yards * 0.9144,
|
||||
|
@ -37,7 +39,7 @@ TO_METERS: Dict[str, Callable[[float], float]] = {
|
|||
LENGTH_MILLIMETERS: lambda millimeters: millimeters * 0.001,
|
||||
}
|
||||
|
||||
METERS_TO: Dict[str, Callable[[float], float]] = {
|
||||
METERS_TO: dict[str, Callable[[float], float]] = {
|
||||
LENGTH_METERS: lambda meters: meters,
|
||||
LENGTH_MILES: lambda meters: meters * 0.000621371,
|
||||
LENGTH_YARD: lambda meters: meters * 1.09361,
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
"""Helper methods to handle the time in Home Assistant."""
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import ciso8601
|
||||
import pytz
|
||||
|
@ -40,7 +42,7 @@ def set_default_time_zone(time_zone: dt.tzinfo) -> None:
|
|||
DEFAULT_TIME_ZONE = time_zone
|
||||
|
||||
|
||||
def get_time_zone(time_zone_str: str) -> Optional[dt.tzinfo]:
|
||||
def get_time_zone(time_zone_str: str) -> dt.tzinfo | None:
|
||||
"""Get time zone from string. Return None if unable to determine.
|
||||
|
||||
Async friendly.
|
||||
|
@ -56,7 +58,7 @@ def utcnow() -> dt.datetime:
|
|||
return dt.datetime.now(NATIVE_UTC)
|
||||
|
||||
|
||||
def now(time_zone: Optional[dt.tzinfo] = None) -> dt.datetime:
|
||||
def now(time_zone: dt.tzinfo | None = None) -> dt.datetime:
|
||||
"""Get now in specified time zone."""
|
||||
return dt.datetime.now(time_zone or DEFAULT_TIME_ZONE)
|
||||
|
||||
|
@ -77,7 +79,7 @@ def as_utc(dattim: dt.datetime) -> dt.datetime:
|
|||
def as_timestamp(dt_value: dt.datetime) -> float:
|
||||
"""Convert a date/time into a unix time (seconds since 1970)."""
|
||||
if hasattr(dt_value, "timestamp"):
|
||||
parsed_dt: Optional[dt.datetime] = dt_value
|
||||
parsed_dt: dt.datetime | None = dt_value
|
||||
else:
|
||||
parsed_dt = parse_datetime(str(dt_value))
|
||||
if parsed_dt is None:
|
||||
|
@ -100,9 +102,7 @@ def utc_from_timestamp(timestamp: float) -> dt.datetime:
|
|||
return UTC.localize(dt.datetime.utcfromtimestamp(timestamp))
|
||||
|
||||
|
||||
def start_of_local_day(
|
||||
dt_or_d: Union[dt.date, dt.datetime, None] = None
|
||||
) -> dt.datetime:
|
||||
def start_of_local_day(dt_or_d: dt.date | dt.datetime | None = None) -> dt.datetime:
|
||||
"""Return local datetime object of start of day from date or datetime."""
|
||||
if dt_or_d is None:
|
||||
date: dt.date = now().date()
|
||||
|
@ -119,7 +119,7 @@ def start_of_local_day(
|
|||
# Copyright (c) Django Software Foundation and individual contributors.
|
||||
# All rights reserved.
|
||||
# https://github.com/django/django/blob/master/LICENSE
|
||||
def parse_datetime(dt_str: str) -> Optional[dt.datetime]:
|
||||
def parse_datetime(dt_str: str) -> dt.datetime | None:
|
||||
"""Parse a string and return a datetime.datetime.
|
||||
|
||||
This function supports time zone offsets. When the input contains one,
|
||||
|
@ -134,12 +134,12 @@ def parse_datetime(dt_str: str) -> Optional[dt.datetime]:
|
|||
match = DATETIME_RE.match(dt_str)
|
||||
if not match:
|
||||
return None
|
||||
kws: Dict[str, Any] = match.groupdict()
|
||||
kws: dict[str, Any] = match.groupdict()
|
||||
if kws["microsecond"]:
|
||||
kws["microsecond"] = kws["microsecond"].ljust(6, "0")
|
||||
tzinfo_str = kws.pop("tzinfo")
|
||||
|
||||
tzinfo: Optional[dt.tzinfo] = None
|
||||
tzinfo: dt.tzinfo | None = None
|
||||
if tzinfo_str == "Z":
|
||||
tzinfo = UTC
|
||||
elif tzinfo_str is not None:
|
||||
|
@ -154,7 +154,7 @@ def parse_datetime(dt_str: str) -> Optional[dt.datetime]:
|
|||
return dt.datetime(**kws)
|
||||
|
||||
|
||||
def parse_date(dt_str: str) -> Optional[dt.date]:
|
||||
def parse_date(dt_str: str) -> dt.date | None:
|
||||
"""Convert a date string to a date object."""
|
||||
try:
|
||||
return dt.datetime.strptime(dt_str, DATE_STR_FORMAT).date()
|
||||
|
@ -162,7 +162,7 @@ def parse_date(dt_str: str) -> Optional[dt.date]:
|
|||
return None
|
||||
|
||||
|
||||
def parse_time(time_str: str) -> Optional[dt.time]:
|
||||
def parse_time(time_str: str) -> dt.time | None:
|
||||
"""Parse a time string (00:20:00) into Time object.
|
||||
|
||||
Return None if invalid.
|
||||
|
@ -213,7 +213,7 @@ def get_age(date: dt.datetime) -> str:
|
|||
return formatn(rounded_delta, selected_unit)
|
||||
|
||||
|
||||
def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> List[int]:
|
||||
def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> list[int]:
|
||||
"""Parse the time expression part and return a list of times to match."""
|
||||
if parameter is None or parameter == MATCH_ALL:
|
||||
res = list(range(min_value, max_value + 1))
|
||||
|
@ -241,9 +241,9 @@ def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> Lis
|
|||
|
||||
def find_next_time_expression_time(
|
||||
now: dt.datetime, # pylint: disable=redefined-outer-name
|
||||
seconds: List[int],
|
||||
minutes: List[int],
|
||||
hours: List[int],
|
||||
seconds: list[int],
|
||||
minutes: list[int],
|
||||
hours: list[int],
|
||||
) -> dt.datetime:
|
||||
"""Find the next datetime from now for which the time expression matches.
|
||||
|
||||
|
@ -257,7 +257,7 @@ def find_next_time_expression_time(
|
|||
if not seconds or not minutes or not hours:
|
||||
raise ValueError("Cannot find a next time: Time expression never matches!")
|
||||
|
||||
def _lower_bound(arr: List[int], cmp: int) -> Optional[int]:
|
||||
def _lower_bound(arr: list[int], cmp: int) -> int | None:
|
||||
"""Return the first value in arr greater or equal to cmp.
|
||||
|
||||
Return None if no such value exists.
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
"""JSON utility functions."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
from typing import Any, Callable
|
||||
|
||||
from homeassistant.core import Event, State
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
@ -20,9 +22,7 @@ class WriteError(HomeAssistantError):
|
|||
"""Error writing the data."""
|
||||
|
||||
|
||||
def load_json(
|
||||
filename: str, default: Union[List, Dict, None] = None
|
||||
) -> Union[List, Dict]:
|
||||
def load_json(filename: str, default: list | dict | None = None) -> list | dict:
|
||||
"""Load JSON data from a file and return as dict or list.
|
||||
|
||||
Defaults to returning empty dict if file is not found.
|
||||
|
@ -44,10 +44,10 @@ def load_json(
|
|||
|
||||
def save_json(
|
||||
filename: str,
|
||||
data: Union[List, Dict],
|
||||
data: list | dict,
|
||||
private: bool = False,
|
||||
*,
|
||||
encoder: Optional[Type[json.JSONEncoder]] = None,
|
||||
encoder: type[json.JSONEncoder] | None = None,
|
||||
) -> None:
|
||||
"""Save JSON data to a file.
|
||||
|
||||
|
@ -85,7 +85,7 @@ def save_json(
|
|||
_LOGGER.error("JSON replacement cleanup failed: %s", err)
|
||||
|
||||
|
||||
def format_unserializable_data(data: Dict[str, Any]) -> str:
|
||||
def format_unserializable_data(data: dict[str, Any]) -> str:
|
||||
"""Format output of find_paths in a friendly way.
|
||||
|
||||
Format is comma separated: <path>=<value>(<type>)
|
||||
|
@ -95,7 +95,7 @@ def format_unserializable_data(data: Dict[str, Any]) -> str:
|
|||
|
||||
def find_paths_unserializable_data(
|
||||
bad_data: Any, *, dump: Callable[[Any], str] = json.dumps
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Find the paths to unserializable data.
|
||||
|
||||
This method is slow! Only use for error handling.
|
||||
|
|
|
@ -3,10 +3,12 @@ Module with location helpers.
|
|||
|
||||
detect_location_info and elevation are mocked by default during tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
@ -47,7 +49,7 @@ LocationInfo = collections.namedtuple(
|
|||
|
||||
async def async_detect_location_info(
|
||||
session: aiohttp.ClientSession,
|
||||
) -> Optional[LocationInfo]:
|
||||
) -> LocationInfo | None:
|
||||
"""Detect location information."""
|
||||
data = await _get_ipapi(session)
|
||||
|
||||
|
@ -63,8 +65,8 @@ async def async_detect_location_info(
|
|||
|
||||
|
||||
def distance(
|
||||
lat1: Optional[float], lon1: Optional[float], lat2: float, lon2: float
|
||||
) -> Optional[float]:
|
||||
lat1: float | None, lon1: float | None, lat2: float, lon2: float
|
||||
) -> float | None:
|
||||
"""Calculate the distance in meters between two points.
|
||||
|
||||
Async friendly.
|
||||
|
@ -81,8 +83,8 @@ def distance(
|
|||
# Source: https://github.com/maurycyp/vincenty
|
||||
# License: https://github.com/maurycyp/vincenty/blob/master/LICENSE
|
||||
def vincenty(
|
||||
point1: Tuple[float, float], point2: Tuple[float, float], miles: bool = False
|
||||
) -> Optional[float]:
|
||||
point1: tuple[float, float], point2: tuple[float, float], miles: bool = False
|
||||
) -> float | None:
|
||||
"""
|
||||
Vincenty formula (inverse method) to calculate the distance.
|
||||
|
||||
|
@ -162,7 +164,7 @@ def vincenty(
|
|||
return round(s, 6)
|
||||
|
||||
|
||||
async def _get_ipapi(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]:
|
||||
async def _get_ipapi(session: aiohttp.ClientSession) -> dict[str, Any] | None:
|
||||
"""Query ipapi.co for location data."""
|
||||
try:
|
||||
resp = await session.get(IPAPI, timeout=5)
|
||||
|
@ -192,7 +194,7 @@ async def _get_ipapi(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]
|
|||
}
|
||||
|
||||
|
||||
async def _get_ip_api(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]:
|
||||
async def _get_ip_api(session: aiohttp.ClientSession) -> dict[str, Any] | None:
|
||||
"""Query ip-api.com for location data."""
|
||||
try:
|
||||
resp = await session.get(IP_API, timeout=5)
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""Logging utilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from functools import partial, wraps
|
||||
import inspect
|
||||
|
@ -6,7 +8,7 @@ import logging
|
|||
import logging.handlers
|
||||
import queue
|
||||
import traceback
|
||||
from typing import Any, Awaitable, Callable, Coroutine, Union, cast, overload
|
||||
from typing import Any, Awaitable, Callable, Coroutine, cast, overload
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
@ -115,7 +117,7 @@ def catch_log_exception(
|
|||
|
||||
def catch_log_exception(
|
||||
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
|
||||
) -> Union[Callable[..., None], Callable[..., Awaitable[None]]]:
|
||||
) -> Callable[..., None] | Callable[..., Awaitable[None]]:
|
||||
"""Decorate a callback to catch and log exceptions."""
|
||||
|
||||
# Check for partials to properly determine if coroutine function
|
||||
|
@ -123,7 +125,7 @@ def catch_log_exception(
|
|||
while isinstance(check_func, partial):
|
||||
check_func = check_func.func
|
||||
|
||||
wrapper_func: Union[Callable[..., None], Callable[..., Awaitable[None]]]
|
||||
wrapper_func: Callable[..., None] | Callable[..., Awaitable[None]]
|
||||
if asyncio.iscoroutinefunction(check_func):
|
||||
async_func = cast(Callable[..., Awaitable[None]], func)
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Network utilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address, IPv6Address, ip_address, ip_network
|
||||
from typing import Union
|
||||
|
||||
import yarl
|
||||
|
||||
|
@ -23,22 +24,22 @@ PRIVATE_NETWORKS = (
|
|||
LINK_LOCAL_NETWORK = ip_network("169.254.0.0/16")
|
||||
|
||||
|
||||
def is_loopback(address: Union[IPv4Address, IPv6Address]) -> bool:
|
||||
def is_loopback(address: IPv4Address | IPv6Address) -> bool:
|
||||
"""Check if an address is a loopback address."""
|
||||
return any(address in network for network in LOOPBACK_NETWORKS)
|
||||
|
||||
|
||||
def is_private(address: Union[IPv4Address, IPv6Address]) -> bool:
|
||||
def is_private(address: IPv4Address | IPv6Address) -> bool:
|
||||
"""Check if an address is a private address."""
|
||||
return any(address in network for network in PRIVATE_NETWORKS)
|
||||
|
||||
|
||||
def is_link_local(address: Union[IPv4Address, IPv6Address]) -> bool:
|
||||
def is_link_local(address: IPv4Address | IPv6Address) -> bool:
|
||||
"""Check if an address is link local."""
|
||||
return address in LINK_LOCAL_NETWORK
|
||||
|
||||
|
||||
def is_local(address: Union[IPv4Address, IPv6Address]) -> bool:
|
||||
def is_local(address: IPv4Address | IPv6Address) -> bool:
|
||||
"""Check if an address is loopback or private."""
|
||||
return is_loopback(address) or is_private(address)
|
||||
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""Helpers to install PyPi packages."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
import logging
|
||||
|
@ -6,7 +8,6 @@ import os
|
|||
from pathlib import Path
|
||||
from subprocess import PIPE, Popen
|
||||
import sys
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pkg_resources
|
||||
|
@ -59,10 +60,10 @@ def is_installed(package: str) -> bool:
|
|||
def install_package(
|
||||
package: str,
|
||||
upgrade: bool = True,
|
||||
target: Optional[str] = None,
|
||||
constraints: Optional[str] = None,
|
||||
find_links: Optional[str] = None,
|
||||
no_cache_dir: Optional[bool] = False,
|
||||
target: str | None = None,
|
||||
constraints: str | None = None,
|
||||
find_links: str | None = None,
|
||||
no_cache_dir: bool | None = False,
|
||||
) -> bool:
|
||||
"""Install a package on PyPi. Accepts pip compatible package strings.
|
||||
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
"""Percentage util functions."""
|
||||
|
||||
from typing import List, Tuple
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def ordered_list_item_to_percentage(ordered_list: List[str], item: str) -> int:
|
||||
def ordered_list_item_to_percentage(ordered_list: list[str], item: str) -> int:
|
||||
"""Determine the percentage of an item in an ordered list.
|
||||
|
||||
When using this utility for fan speeds, do not include "off"
|
||||
|
@ -26,7 +25,7 @@ def ordered_list_item_to_percentage(ordered_list: List[str], item: str) -> int:
|
|||
return (list_position * 100) // list_len
|
||||
|
||||
|
||||
def percentage_to_ordered_list_item(ordered_list: List[str], percentage: int) -> str:
|
||||
def percentage_to_ordered_list_item(ordered_list: list[str], percentage: int) -> str:
|
||||
"""Find the item that most closely matches the percentage in an ordered list.
|
||||
|
||||
When using this utility for fan speeds, do not include "off"
|
||||
|
@ -54,7 +53,7 @@ def percentage_to_ordered_list_item(ordered_list: List[str], percentage: int) ->
|
|||
|
||||
|
||||
def ranged_value_to_percentage(
|
||||
low_high_range: Tuple[float, float], value: float
|
||||
low_high_range: tuple[float, float], value: float
|
||||
) -> int:
|
||||
"""Given a range of low and high values convert a single value to a percentage.
|
||||
|
||||
|
@ -71,7 +70,7 @@ def ranged_value_to_percentage(
|
|||
|
||||
|
||||
def percentage_to_ranged_value(
|
||||
low_high_range: Tuple[float, float], percentage: int
|
||||
low_high_range: tuple[float, float], percentage: int
|
||||
) -> float:
|
||||
"""Given a range of low and high values convert a percentage to a single value.
|
||||
|
||||
|
@ -87,11 +86,11 @@ def percentage_to_ranged_value(
|
|||
return states_in_range(low_high_range) * percentage / 100
|
||||
|
||||
|
||||
def states_in_range(low_high_range: Tuple[float, float]) -> float:
|
||||
def states_in_range(low_high_range: tuple[float, float]) -> float:
|
||||
"""Given a range of low and high values return how many states exist."""
|
||||
return low_high_range[1] - low_high_range[0] + 1
|
||||
|
||||
|
||||
def int_states_in_range(low_high_range: Tuple[float, float]) -> int:
|
||||
def int_states_in_range(low_high_range: tuple[float, float]) -> int:
|
||||
"""Given a range of low and high values return how many integer states exist."""
|
||||
return int(states_in_range(low_high_range))
|
||||
|
|
|
@ -2,18 +2,18 @@
|
|||
|
||||
Can only be used by integrations that have pillow in their requirements.
|
||||
"""
|
||||
from typing import Tuple
|
||||
from __future__ import annotations
|
||||
|
||||
from PIL import ImageDraw
|
||||
|
||||
|
||||
def draw_box(
|
||||
draw: ImageDraw,
|
||||
box: Tuple[float, float, float, float],
|
||||
box: tuple[float, float, float, float],
|
||||
img_width: int,
|
||||
img_height: int,
|
||||
text: str = "",
|
||||
color: Tuple[int, int, int] = (255, 255, 0),
|
||||
color: tuple[int, int, int] = (255, 255, 0),
|
||||
) -> None:
|
||||
"""
|
||||
Draw a bounding box on and image.
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
"""ruamel.yaml utility functions."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
import os
|
||||
from os import O_CREAT, O_TRUNC, O_WRONLY, stat_result
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import ruamel.yaml
|
||||
from ruamel.yaml import YAML # type: ignore
|
||||
|
@ -22,7 +24,7 @@ JSON_TYPE = Union[List, Dict, str] # pylint: disable=invalid-name
|
|||
class ExtSafeConstructor(SafeConstructor):
|
||||
"""Extended SafeConstructor."""
|
||||
|
||||
name: Optional[str] = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class UnsupportedYamlError(HomeAssistantError):
|
||||
|
@ -77,7 +79,7 @@ def yaml_to_object(data: str) -> JSON_TYPE:
|
|||
"""Create object from yaml string."""
|
||||
yaml = YAML(typ="rt")
|
||||
try:
|
||||
result: Union[List, Dict, str] = yaml.load(data)
|
||||
result: list | dict | str = yaml.load(data)
|
||||
return result
|
||||
except YAMLError as exc:
|
||||
_LOGGER.error("YAML error: %s", exc)
|
||||
|
|
|
@ -8,7 +8,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import enum
|
||||
from types import TracebackType
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from typing import Any
|
||||
|
||||
from .async_ import run_callback_threadsafe
|
||||
|
||||
|
@ -38,10 +38,10 @@ class _GlobalFreezeContext:
|
|||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Type[BaseException],
|
||||
exc_type: type[BaseException],
|
||||
exc_val: BaseException,
|
||||
exc_tb: TracebackType,
|
||||
) -> Optional[bool]:
|
||||
) -> bool | None:
|
||||
self._exit()
|
||||
return None
|
||||
|
||||
|
@ -51,10 +51,10 @@ class _GlobalFreezeContext:
|
|||
|
||||
def __exit__( # pylint: disable=useless-return
|
||||
self,
|
||||
exc_type: Type[BaseException],
|
||||
exc_type: type[BaseException],
|
||||
exc_val: BaseException,
|
||||
exc_tb: TracebackType,
|
||||
) -> Optional[bool]:
|
||||
) -> bool | None:
|
||||
self._loop.call_soon_threadsafe(self._exit)
|
||||
return None
|
||||
|
||||
|
@ -106,10 +106,10 @@ class _ZoneFreezeContext:
|
|||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Type[BaseException],
|
||||
exc_type: type[BaseException],
|
||||
exc_val: BaseException,
|
||||
exc_tb: TracebackType,
|
||||
) -> Optional[bool]:
|
||||
) -> bool | None:
|
||||
self._exit()
|
||||
return None
|
||||
|
||||
|
@ -119,10 +119,10 @@ class _ZoneFreezeContext:
|
|||
|
||||
def __exit__( # pylint: disable=useless-return
|
||||
self,
|
||||
exc_type: Type[BaseException],
|
||||
exc_type: type[BaseException],
|
||||
exc_val: BaseException,
|
||||
exc_tb: TracebackType,
|
||||
) -> Optional[bool]:
|
||||
) -> bool | None:
|
||||
self._loop.call_soon_threadsafe(self._exit)
|
||||
return None
|
||||
|
||||
|
@ -155,8 +155,8 @@ class _GlobalTaskContext:
|
|||
self._manager: TimeoutManager = manager
|
||||
self._task: asyncio.Task[Any] = task
|
||||
self._time_left: float = timeout
|
||||
self._expiration_time: Optional[float] = None
|
||||
self._timeout_handler: Optional[asyncio.Handle] = None
|
||||
self._expiration_time: float | None = None
|
||||
self._timeout_handler: asyncio.Handle | None = None
|
||||
self._wait_zone: asyncio.Event = asyncio.Event()
|
||||
self._state: _State = _State.INIT
|
||||
self._cool_down: float = cool_down
|
||||
|
@ -169,10 +169,10 @@ class _GlobalTaskContext:
|
|||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Type[BaseException],
|
||||
exc_type: type[BaseException],
|
||||
exc_val: BaseException,
|
||||
exc_tb: TracebackType,
|
||||
) -> Optional[bool]:
|
||||
) -> bool | None:
|
||||
self._stop_timer()
|
||||
self._manager.global_tasks.remove(self)
|
||||
|
||||
|
@ -263,8 +263,8 @@ class _ZoneTaskContext:
|
|||
self._task: asyncio.Task[Any] = task
|
||||
self._state: _State = _State.INIT
|
||||
self._time_left: float = timeout
|
||||
self._expiration_time: Optional[float] = None
|
||||
self._timeout_handler: Optional[asyncio.Handle] = None
|
||||
self._expiration_time: float | None = None
|
||||
self._timeout_handler: asyncio.Handle | None = None
|
||||
|
||||
@property
|
||||
def state(self) -> _State:
|
||||
|
@ -283,10 +283,10 @@ class _ZoneTaskContext:
|
|||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Type[BaseException],
|
||||
exc_type: type[BaseException],
|
||||
exc_val: BaseException,
|
||||
exc_tb: TracebackType,
|
||||
) -> Optional[bool]:
|
||||
) -> bool | None:
|
||||
self._zone.exit_task(self)
|
||||
self._stop_timer()
|
||||
|
||||
|
@ -344,8 +344,8 @@ class _ZoneTimeoutManager:
|
|||
"""Initialize internal timeout context manager."""
|
||||
self._manager: TimeoutManager = manager
|
||||
self._zone: str = zone
|
||||
self._tasks: List[_ZoneTaskContext] = []
|
||||
self._freezes: List[_ZoneFreezeContext] = []
|
||||
self._tasks: list[_ZoneTaskContext] = []
|
||||
self._freezes: list[_ZoneFreezeContext] = []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Representation of a zone."""
|
||||
|
@ -418,9 +418,9 @@ class TimeoutManager:
|
|||
def __init__(self) -> None:
|
||||
"""Initialize TimeoutManager."""
|
||||
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
|
||||
self._zones: Dict[str, _ZoneTimeoutManager] = {}
|
||||
self._globals: List[_GlobalTaskContext] = []
|
||||
self._freezes: List[_GlobalFreezeContext] = []
|
||||
self._zones: dict[str, _ZoneTimeoutManager] = {}
|
||||
self._globals: list[_GlobalTaskContext] = []
|
||||
self._freezes: list[_GlobalFreezeContext] = []
|
||||
|
||||
@property
|
||||
def zones_done(self) -> bool:
|
||||
|
@ -433,17 +433,17 @@ class TimeoutManager:
|
|||
return not self._freezes
|
||||
|
||||
@property
|
||||
def zones(self) -> Dict[str, _ZoneTimeoutManager]:
|
||||
def zones(self) -> dict[str, _ZoneTimeoutManager]:
|
||||
"""Return all Zones."""
|
||||
return self._zones
|
||||
|
||||
@property
|
||||
def global_tasks(self) -> List[_GlobalTaskContext]:
|
||||
def global_tasks(self) -> list[_GlobalTaskContext]:
|
||||
"""Return all global Tasks."""
|
||||
return self._globals
|
||||
|
||||
@property
|
||||
def global_freezes(self) -> List[_GlobalFreezeContext]:
|
||||
def global_freezes(self) -> list[_GlobalFreezeContext]:
|
||||
"""Return all global Freezes."""
|
||||
return self._freezes
|
||||
|
||||
|
@ -459,12 +459,12 @@ class TimeoutManager:
|
|||
|
||||
def async_timeout(
|
||||
self, timeout: float, zone_name: str = ZONE_GLOBAL, cool_down: float = 0
|
||||
) -> Union[_ZoneTaskContext, _GlobalTaskContext]:
|
||||
) -> _ZoneTaskContext | _GlobalTaskContext:
|
||||
"""Timeout based on a zone.
|
||||
|
||||
For using as Async Context Manager.
|
||||
"""
|
||||
current_task: Optional[asyncio.Task[Any]] = asyncio.current_task()
|
||||
current_task: asyncio.Task[Any] | None = asyncio.current_task()
|
||||
assert current_task
|
||||
|
||||
# Global Zone
|
||||
|
@ -483,7 +483,7 @@ class TimeoutManager:
|
|||
|
||||
def async_freeze(
|
||||
self, zone_name: str = ZONE_GLOBAL
|
||||
) -> Union[_ZoneFreezeContext, _GlobalFreezeContext]:
|
||||
) -> _ZoneFreezeContext | _GlobalFreezeContext:
|
||||
"""Freeze all timer until job is done.
|
||||
|
||||
For using as Async Context Manager.
|
||||
|
@ -502,7 +502,7 @@ class TimeoutManager:
|
|||
|
||||
def freeze(
|
||||
self, zone_name: str = ZONE_GLOBAL
|
||||
) -> Union[_ZoneFreezeContext, _GlobalFreezeContext]:
|
||||
) -> _ZoneFreezeContext | _GlobalFreezeContext:
|
||||
"""Freeze all timer until job is done.
|
||||
|
||||
For using as Context Manager.
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Unit system helper class and methods."""
|
||||
from __future__ import annotations
|
||||
|
||||
from numbers import Number
|
||||
from typing import Dict, Optional
|
||||
|
||||
from homeassistant.const import (
|
||||
CONF_UNIT_SYSTEM_IMPERIAL,
|
||||
|
@ -109,7 +110,7 @@ class UnitSystem:
|
|||
|
||||
return temperature_util.convert(temperature, from_unit, self.temperature_unit)
|
||||
|
||||
def length(self, length: Optional[float], from_unit: str) -> float:
|
||||
def length(self, length: float | None, from_unit: str) -> float:
|
||||
"""Convert the given length to this unit system."""
|
||||
if not isinstance(length, Number):
|
||||
raise TypeError(f"{length!s} is not a numeric value.")
|
||||
|
@ -119,7 +120,7 @@ class UnitSystem:
|
|||
length, from_unit, self.length_unit
|
||||
)
|
||||
|
||||
def pressure(self, pressure: Optional[float], from_unit: str) -> float:
|
||||
def pressure(self, pressure: float | None, from_unit: str) -> float:
|
||||
"""Convert the given pressure to this unit system."""
|
||||
if not isinstance(pressure, Number):
|
||||
raise TypeError(f"{pressure!s} is not a numeric value.")
|
||||
|
@ -129,7 +130,7 @@ class UnitSystem:
|
|||
pressure, from_unit, self.pressure_unit
|
||||
)
|
||||
|
||||
def volume(self, volume: Optional[float], from_unit: str) -> float:
|
||||
def volume(self, volume: float | None, from_unit: str) -> float:
|
||||
"""Convert the given volume to this unit system."""
|
||||
if not isinstance(volume, Number):
|
||||
raise TypeError(f"{volume!s} is not a numeric value.")
|
||||
|
@ -137,7 +138,7 @@ class UnitSystem:
|
|||
# type ignore: https://github.com/python/mypy/issues/7207
|
||||
return volume_util.convert(volume, from_unit, self.volume_unit) # type: ignore
|
||||
|
||||
def as_dict(self) -> Dict[str, str]:
|
||||
def as_dict(self) -> dict[str, str]:
|
||||
"""Convert the unit system to a dictionary."""
|
||||
return {
|
||||
LENGTH: self.length_unit,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Deal with YAML input."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Set
|
||||
from typing import Any
|
||||
|
||||
from .objects import Input
|
||||
|
||||
|
@ -14,14 +15,14 @@ class UndefinedSubstitution(Exception):
|
|||
self.input = input
|
||||
|
||||
|
||||
def extract_inputs(obj: Any) -> Set[str]:
|
||||
def extract_inputs(obj: Any) -> set[str]:
|
||||
"""Extract input from a structure."""
|
||||
found: Set[str] = set()
|
||||
found: set[str] = set()
|
||||
_extract_inputs(obj, found)
|
||||
return found
|
||||
|
||||
|
||||
def _extract_inputs(obj: Any, found: Set[str]) -> None:
|
||||
def _extract_inputs(obj: Any, found: set[str]) -> None:
|
||||
"""Extract input from a structure."""
|
||||
if isinstance(obj, Input):
|
||||
found.add(obj.name)
|
||||
|
@ -38,7 +39,7 @@ def _extract_inputs(obj: Any, found: Set[str]) -> None:
|
|||
return
|
||||
|
||||
|
||||
def substitute(obj: Any, substitutions: Dict[str, Any]) -> Any:
|
||||
def substitute(obj: Any, substitutions: dict[str, Any]) -> Any:
|
||||
"""Substitute values."""
|
||||
if isinstance(obj, Input):
|
||||
if obj.name not in substitutions:
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
"""Custom loader."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
import fnmatch
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterator, List, Optional, TextIO, TypeVar, Union, overload
|
||||
from typing import Any, Dict, Iterator, List, TextIO, TypeVar, Union, overload
|
||||
|
||||
import yaml
|
||||
|
||||
|
@ -27,7 +29,7 @@ class Secrets:
|
|||
def __init__(self, config_dir: Path):
|
||||
"""Initialize secrets."""
|
||||
self.config_dir = config_dir
|
||||
self._cache: Dict[Path, Dict[str, str]] = {}
|
||||
self._cache: dict[Path, dict[str, str]] = {}
|
||||
|
||||
def get(self, requester_path: str, secret: str) -> str:
|
||||
"""Return the value of a secret."""
|
||||
|
@ -55,7 +57,7 @@ class Secrets:
|
|||
|
||||
raise HomeAssistantError(f"Secret {secret} not defined")
|
||||
|
||||
def _load_secret_yaml(self, secret_dir: Path) -> Dict[str, str]:
|
||||
def _load_secret_yaml(self, secret_dir: Path) -> dict[str, str]:
|
||||
"""Load the secrets yaml from path."""
|
||||
secret_path = secret_dir / SECRET_YAML
|
||||
|
||||
|
@ -90,7 +92,7 @@ class Secrets:
|
|||
class SafeLineLoader(yaml.SafeLoader):
|
||||
"""Loader class that keeps track of line numbers."""
|
||||
|
||||
def __init__(self, stream: Any, secrets: Optional[Secrets] = None) -> None:
|
||||
def __init__(self, stream: Any, secrets: Secrets | None = None) -> None:
|
||||
"""Initialize a safe line loader."""
|
||||
super().__init__(stream)
|
||||
self.secrets = secrets
|
||||
|
@ -103,7 +105,7 @@ class SafeLineLoader(yaml.SafeLoader):
|
|||
return node
|
||||
|
||||
|
||||
def load_yaml(fname: str, secrets: Optional[Secrets] = None) -> JSON_TYPE:
|
||||
def load_yaml(fname: str, secrets: Secrets | None = None) -> JSON_TYPE:
|
||||
"""Load a YAML file."""
|
||||
try:
|
||||
with open(fname, encoding="utf-8") as conf_file:
|
||||
|
@ -113,9 +115,7 @@ def load_yaml(fname: str, secrets: Optional[Secrets] = None) -> JSON_TYPE:
|
|||
raise HomeAssistantError(exc) from exc
|
||||
|
||||
|
||||
def parse_yaml(
|
||||
content: Union[str, TextIO], secrets: Optional[Secrets] = None
|
||||
) -> JSON_TYPE:
|
||||
def parse_yaml(content: str | TextIO, secrets: Secrets | None = None) -> JSON_TYPE:
|
||||
"""Load a YAML file."""
|
||||
try:
|
||||
# If configuration file is empty YAML returns None
|
||||
|
@ -131,14 +131,14 @@ def parse_yaml(
|
|||
|
||||
@overload
|
||||
def _add_reference(
|
||||
obj: Union[list, NodeListClass], loader: SafeLineLoader, node: yaml.nodes.Node
|
||||
obj: list | NodeListClass, loader: SafeLineLoader, node: yaml.nodes.Node
|
||||
) -> NodeListClass:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def _add_reference(
|
||||
obj: Union[str, NodeStrClass], loader: SafeLineLoader, node: yaml.nodes.Node
|
||||
obj: str | NodeStrClass, loader: SafeLineLoader, node: yaml.nodes.Node
|
||||
) -> NodeStrClass:
|
||||
...
|
||||
|
||||
|
@ -223,7 +223,7 @@ def _include_dir_merge_named_yaml(
|
|||
|
||||
def _include_dir_list_yaml(
|
||||
loader: SafeLineLoader, node: yaml.nodes.Node
|
||||
) -> List[JSON_TYPE]:
|
||||
) -> list[JSON_TYPE]:
|
||||
"""Load multiple files from directory as a list."""
|
||||
loc = os.path.join(os.path.dirname(loader.name), node.value)
|
||||
return [
|
||||
|
@ -238,7 +238,7 @@ def _include_dir_merge_list_yaml(
|
|||
) -> JSON_TYPE:
|
||||
"""Load multiple files from directory as a merged list."""
|
||||
loc: str = os.path.join(os.path.dirname(loader.name), node.value)
|
||||
merged_list: List[JSON_TYPE] = []
|
||||
merged_list: list[JSON_TYPE] = []
|
||||
for fname in _find_files(loc, "*.yaml"):
|
||||
if os.path.basename(fname) == SECRET_YAML:
|
||||
continue
|
||||
|
@ -253,7 +253,7 @@ def _ordered_dict(loader: SafeLineLoader, node: yaml.nodes.MappingNode) -> Order
|
|||
loader.flatten_mapping(node)
|
||||
nodes = loader.construct_pairs(node)
|
||||
|
||||
seen: Dict = {}
|
||||
seen: dict = {}
|
||||
for (key, _), (child_node, _) in zip(nodes, node.value):
|
||||
line = child_node.start_mark.line
|
||||
|
||||
|
|
Loading…
Reference in New Issue