Update typing 03 (#48015)
parent
6fb2e63e49
commit
fabd73f08b
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
from typing import Any, Dict, Optional, Tuple, cast
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
|
@ -36,8 +36,8 @@ class InvalidProvider(Exception):
|
||||||
|
|
||||||
async def auth_manager_from_config(
|
async def auth_manager_from_config(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
provider_configs: List[Dict[str, Any]],
|
provider_configs: list[dict[str, Any]],
|
||||||
module_configs: List[Dict[str, Any]],
|
module_configs: list[dict[str, Any]],
|
||||||
) -> AuthManager:
|
) -> AuthManager:
|
||||||
"""Initialize an auth manager from config.
|
"""Initialize an auth manager from config.
|
||||||
|
|
||||||
|
@ -87,8 +87,8 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
||||||
self,
|
self,
|
||||||
handler_key: Any,
|
handler_key: Any,
|
||||||
*,
|
*,
|
||||||
context: Optional[Dict[str, Any]] = None,
|
context: dict[str, Any] | None = None,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> data_entry_flow.FlowHandler:
|
) -> data_entry_flow.FlowHandler:
|
||||||
"""Create a login flow."""
|
"""Create a login flow."""
|
||||||
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
|
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
|
||||||
|
@ -97,8 +97,8 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
||||||
return await auth_provider.async_login_flow(context)
|
return await auth_provider.async_login_flow(context)
|
||||||
|
|
||||||
async def async_finish_flow(
|
async def async_finish_flow(
|
||||||
self, flow: data_entry_flow.FlowHandler, result: Dict[str, Any]
|
self, flow: data_entry_flow.FlowHandler, result: dict[str, Any]
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Return a user as result of login flow."""
|
"""Return a user as result of login flow."""
|
||||||
flow = cast(LoginFlow, flow)
|
flow = cast(LoginFlow, flow)
|
||||||
|
|
||||||
|
@ -157,22 +157,22 @@ class AuthManager:
|
||||||
self.login_flow = AuthManagerFlowManager(hass, self)
|
self.login_flow = AuthManagerFlowManager(hass, self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def auth_providers(self) -> List[AuthProvider]:
|
def auth_providers(self) -> list[AuthProvider]:
|
||||||
"""Return a list of available auth providers."""
|
"""Return a list of available auth providers."""
|
||||||
return list(self._providers.values())
|
return list(self._providers.values())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def auth_mfa_modules(self) -> List[MultiFactorAuthModule]:
|
def auth_mfa_modules(self) -> list[MultiFactorAuthModule]:
|
||||||
"""Return a list of available auth modules."""
|
"""Return a list of available auth modules."""
|
||||||
return list(self._mfa_modules.values())
|
return list(self._mfa_modules.values())
|
||||||
|
|
||||||
def get_auth_provider(
|
def get_auth_provider(
|
||||||
self, provider_type: str, provider_id: Optional[str]
|
self, provider_type: str, provider_id: str | None
|
||||||
) -> Optional[AuthProvider]:
|
) -> AuthProvider | None:
|
||||||
"""Return an auth provider, None if not found."""
|
"""Return an auth provider, None if not found."""
|
||||||
return self._providers.get((provider_type, provider_id))
|
return self._providers.get((provider_type, provider_id))
|
||||||
|
|
||||||
def get_auth_providers(self, provider_type: str) -> List[AuthProvider]:
|
def get_auth_providers(self, provider_type: str) -> list[AuthProvider]:
|
||||||
"""Return a List of auth provider of one type, Empty if not found."""
|
"""Return a List of auth provider of one type, Empty if not found."""
|
||||||
return [
|
return [
|
||||||
provider
|
provider
|
||||||
|
@ -180,30 +180,30 @@ class AuthManager:
|
||||||
if p_type == provider_type
|
if p_type == provider_type
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_auth_mfa_module(self, module_id: str) -> Optional[MultiFactorAuthModule]:
|
def get_auth_mfa_module(self, module_id: str) -> MultiFactorAuthModule | None:
|
||||||
"""Return a multi-factor auth module, None if not found."""
|
"""Return a multi-factor auth module, None if not found."""
|
||||||
return self._mfa_modules.get(module_id)
|
return self._mfa_modules.get(module_id)
|
||||||
|
|
||||||
async def async_get_users(self) -> List[models.User]:
|
async def async_get_users(self) -> list[models.User]:
|
||||||
"""Retrieve all users."""
|
"""Retrieve all users."""
|
||||||
return await self._store.async_get_users()
|
return await self._store.async_get_users()
|
||||||
|
|
||||||
async def async_get_user(self, user_id: str) -> Optional[models.User]:
|
async def async_get_user(self, user_id: str) -> models.User | None:
|
||||||
"""Retrieve a user."""
|
"""Retrieve a user."""
|
||||||
return await self._store.async_get_user(user_id)
|
return await self._store.async_get_user(user_id)
|
||||||
|
|
||||||
async def async_get_owner(self) -> Optional[models.User]:
|
async def async_get_owner(self) -> models.User | None:
|
||||||
"""Retrieve the owner."""
|
"""Retrieve the owner."""
|
||||||
users = await self.async_get_users()
|
users = await self.async_get_users()
|
||||||
return next((user for user in users if user.is_owner), None)
|
return next((user for user in users if user.is_owner), None)
|
||||||
|
|
||||||
async def async_get_group(self, group_id: str) -> Optional[models.Group]:
|
async def async_get_group(self, group_id: str) -> models.Group | None:
|
||||||
"""Retrieve all groups."""
|
"""Retrieve all groups."""
|
||||||
return await self._store.async_get_group(group_id)
|
return await self._store.async_get_group(group_id)
|
||||||
|
|
||||||
async def async_get_user_by_credentials(
|
async def async_get_user_by_credentials(
|
||||||
self, credentials: models.Credentials
|
self, credentials: models.Credentials
|
||||||
) -> Optional[models.User]:
|
) -> models.User | None:
|
||||||
"""Get a user by credential, return None if not found."""
|
"""Get a user by credential, return None if not found."""
|
||||||
for user in await self.async_get_users():
|
for user in await self.async_get_users():
|
||||||
for creds in user.credentials:
|
for creds in user.credentials:
|
||||||
|
@ -213,7 +213,7 @@ class AuthManager:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def async_create_system_user(
|
async def async_create_system_user(
|
||||||
self, name: str, group_ids: Optional[List[str]] = None
|
self, name: str, group_ids: list[str] | None = None
|
||||||
) -> models.User:
|
) -> models.User:
|
||||||
"""Create a system user."""
|
"""Create a system user."""
|
||||||
user = await self._store.async_create_user(
|
user = await self._store.async_create_user(
|
||||||
|
@ -225,10 +225,10 @@ class AuthManager:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def async_create_user(
|
async def async_create_user(
|
||||||
self, name: str, group_ids: Optional[List[str]] = None
|
self, name: str, group_ids: list[str] | None = None
|
||||||
) -> models.User:
|
) -> models.User:
|
||||||
"""Create a user."""
|
"""Create a user."""
|
||||||
kwargs: Dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"is_active": True,
|
"is_active": True,
|
||||||
"group_ids": group_ids or [],
|
"group_ids": group_ids or [],
|
||||||
|
@ -294,12 +294,12 @@ class AuthManager:
|
||||||
async def async_update_user(
|
async def async_update_user(
|
||||||
self,
|
self,
|
||||||
user: models.User,
|
user: models.User,
|
||||||
name: Optional[str] = None,
|
name: str | None = None,
|
||||||
is_active: Optional[bool] = None,
|
is_active: bool | None = None,
|
||||||
group_ids: Optional[List[str]] = None,
|
group_ids: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a user."""
|
"""Update a user."""
|
||||||
kwargs: Dict[str, Any] = {}
|
kwargs: dict[str, Any] = {}
|
||||||
if name is not None:
|
if name is not None:
|
||||||
kwargs["name"] = name
|
kwargs["name"] = name
|
||||||
if group_ids is not None:
|
if group_ids is not None:
|
||||||
|
@ -362,9 +362,9 @@ class AuthManager:
|
||||||
|
|
||||||
await module.async_depose_user(user.id)
|
await module.async_depose_user(user.id)
|
||||||
|
|
||||||
async def async_get_enabled_mfa(self, user: models.User) -> Dict[str, str]:
|
async def async_get_enabled_mfa(self, user: models.User) -> dict[str, str]:
|
||||||
"""List enabled mfa modules for user."""
|
"""List enabled mfa modules for user."""
|
||||||
modules: Dict[str, str] = OrderedDict()
|
modules: dict[str, str] = OrderedDict()
|
||||||
for module_id, module in self._mfa_modules.items():
|
for module_id, module in self._mfa_modules.items():
|
||||||
if await module.async_is_user_setup(user.id):
|
if await module.async_is_user_setup(user.id):
|
||||||
modules[module_id] = module.name
|
modules[module_id] = module.name
|
||||||
|
@ -373,12 +373,12 @@ class AuthManager:
|
||||||
async def async_create_refresh_token(
|
async def async_create_refresh_token(
|
||||||
self,
|
self,
|
||||||
user: models.User,
|
user: models.User,
|
||||||
client_id: Optional[str] = None,
|
client_id: str | None = None,
|
||||||
client_name: Optional[str] = None,
|
client_name: str | None = None,
|
||||||
client_icon: Optional[str] = None,
|
client_icon: str | None = None,
|
||||||
token_type: Optional[str] = None,
|
token_type: str | None = None,
|
||||||
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
||||||
credential: Optional[models.Credentials] = None,
|
credential: models.Credentials | None = None,
|
||||||
) -> models.RefreshToken:
|
) -> models.RefreshToken:
|
||||||
"""Create a new refresh token for a user."""
|
"""Create a new refresh token for a user."""
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
|
@ -432,13 +432,13 @@ class AuthManager:
|
||||||
|
|
||||||
async def async_get_refresh_token(
|
async def async_get_refresh_token(
|
||||||
self, token_id: str
|
self, token_id: str
|
||||||
) -> Optional[models.RefreshToken]:
|
) -> models.RefreshToken | None:
|
||||||
"""Get refresh token by id."""
|
"""Get refresh token by id."""
|
||||||
return await self._store.async_get_refresh_token(token_id)
|
return await self._store.async_get_refresh_token(token_id)
|
||||||
|
|
||||||
async def async_get_refresh_token_by_token(
|
async def async_get_refresh_token_by_token(
|
||||||
self, token: str
|
self, token: str
|
||||||
) -> Optional[models.RefreshToken]:
|
) -> models.RefreshToken | None:
|
||||||
"""Get refresh token by token."""
|
"""Get refresh token by token."""
|
||||||
return await self._store.async_get_refresh_token_by_token(token)
|
return await self._store.async_get_refresh_token_by_token(token)
|
||||||
|
|
||||||
|
@ -450,7 +450,7 @@ class AuthManager:
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_access_token(
|
def async_create_access_token(
|
||||||
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
|
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a new access token."""
|
"""Create a new access token."""
|
||||||
self.async_validate_refresh_token(refresh_token, remote_ip)
|
self.async_validate_refresh_token(refresh_token, remote_ip)
|
||||||
|
@ -471,7 +471,7 @@ class AuthManager:
|
||||||
@callback
|
@callback
|
||||||
def _async_resolve_provider(
|
def _async_resolve_provider(
|
||||||
self, refresh_token: models.RefreshToken
|
self, refresh_token: models.RefreshToken
|
||||||
) -> Optional[AuthProvider]:
|
) -> AuthProvider | None:
|
||||||
"""Get the auth provider for the given refresh token.
|
"""Get the auth provider for the given refresh token.
|
||||||
|
|
||||||
Raises an exception if the expected provider is no longer available or return
|
Raises an exception if the expected provider is no longer available or return
|
||||||
|
@ -492,7 +492,7 @@ class AuthManager:
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_validate_refresh_token(
|
def async_validate_refresh_token(
|
||||||
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
|
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Validate that a refresh token is usable.
|
"""Validate that a refresh token is usable.
|
||||||
|
|
||||||
|
@ -504,7 +504,7 @@ class AuthManager:
|
||||||
|
|
||||||
async def async_validate_access_token(
|
async def async_validate_access_token(
|
||||||
self, token: str
|
self, token: str
|
||||||
) -> Optional[models.RefreshToken]:
|
) -> models.RefreshToken | None:
|
||||||
"""Return refresh token if an access token is valid."""
|
"""Return refresh token if an access token is valid."""
|
||||||
try:
|
try:
|
||||||
unverif_claims = jwt.decode(token, verify=False)
|
unverif_claims = jwt.decode(token, verify=False)
|
||||||
|
@ -535,7 +535,7 @@ class AuthManager:
|
||||||
@callback
|
@callback
|
||||||
def _async_get_auth_provider(
|
def _async_get_auth_provider(
|
||||||
self, credentials: models.Credentials
|
self, credentials: models.Credentials
|
||||||
) -> Optional[AuthProvider]:
|
) -> AuthProvider | None:
|
||||||
"""Get auth provider from a set of credentials."""
|
"""Get auth provider from a set of credentials."""
|
||||||
auth_provider_key = (
|
auth_provider_key = (
|
||||||
credentials.auth_provider_type,
|
credentials.auth_provider_type,
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
"""Storage for auth models."""
|
"""Storage for auth models."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import hmac
|
import hmac
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.auth.const import ACCESS_TOKEN_EXPIRATION
|
from homeassistant.auth.const import ACCESS_TOKEN_EXPIRATION
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
@ -34,15 +36,15 @@ class AuthStore:
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
"""Initialize the auth store."""
|
"""Initialize the auth store."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._users: Optional[Dict[str, models.User]] = None
|
self._users: dict[str, models.User] | None = None
|
||||||
self._groups: Optional[Dict[str, models.Group]] = None
|
self._groups: dict[str, models.Group] | None = None
|
||||||
self._perm_lookup: Optional[PermissionLookup] = None
|
self._perm_lookup: PermissionLookup | None = None
|
||||||
self._store = hass.helpers.storage.Store(
|
self._store = hass.helpers.storage.Store(
|
||||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
STORAGE_VERSION, STORAGE_KEY, private=True
|
||||||
)
|
)
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
async def async_get_groups(self) -> List[models.Group]:
|
async def async_get_groups(self) -> list[models.Group]:
|
||||||
"""Retrieve all users."""
|
"""Retrieve all users."""
|
||||||
if self._groups is None:
|
if self._groups is None:
|
||||||
await self._async_load()
|
await self._async_load()
|
||||||
|
@ -50,7 +52,7 @@ class AuthStore:
|
||||||
|
|
||||||
return list(self._groups.values())
|
return list(self._groups.values())
|
||||||
|
|
||||||
async def async_get_group(self, group_id: str) -> Optional[models.Group]:
|
async def async_get_group(self, group_id: str) -> models.Group | None:
|
||||||
"""Retrieve all users."""
|
"""Retrieve all users."""
|
||||||
if self._groups is None:
|
if self._groups is None:
|
||||||
await self._async_load()
|
await self._async_load()
|
||||||
|
@ -58,7 +60,7 @@ class AuthStore:
|
||||||
|
|
||||||
return self._groups.get(group_id)
|
return self._groups.get(group_id)
|
||||||
|
|
||||||
async def async_get_users(self) -> List[models.User]:
|
async def async_get_users(self) -> list[models.User]:
|
||||||
"""Retrieve all users."""
|
"""Retrieve all users."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self._async_load()
|
await self._async_load()
|
||||||
|
@ -66,7 +68,7 @@ class AuthStore:
|
||||||
|
|
||||||
return list(self._users.values())
|
return list(self._users.values())
|
||||||
|
|
||||||
async def async_get_user(self, user_id: str) -> Optional[models.User]:
|
async def async_get_user(self, user_id: str) -> models.User | None:
|
||||||
"""Retrieve a user by id."""
|
"""Retrieve a user by id."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self._async_load()
|
await self._async_load()
|
||||||
|
@ -76,12 +78,12 @@ class AuthStore:
|
||||||
|
|
||||||
async def async_create_user(
|
async def async_create_user(
|
||||||
self,
|
self,
|
||||||
name: Optional[str],
|
name: str | None,
|
||||||
is_owner: Optional[bool] = None,
|
is_owner: bool | None = None,
|
||||||
is_active: Optional[bool] = None,
|
is_active: bool | None = None,
|
||||||
system_generated: Optional[bool] = None,
|
system_generated: bool | None = None,
|
||||||
credentials: Optional[models.Credentials] = None,
|
credentials: models.Credentials | None = None,
|
||||||
group_ids: Optional[List[str]] = None,
|
group_ids: list[str] | None = None,
|
||||||
) -> models.User:
|
) -> models.User:
|
||||||
"""Create a new user."""
|
"""Create a new user."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
|
@ -97,7 +99,7 @@ class AuthStore:
|
||||||
raise ValueError(f"Invalid group specified {group_id}")
|
raise ValueError(f"Invalid group specified {group_id}")
|
||||||
groups.append(group)
|
groups.append(group)
|
||||||
|
|
||||||
kwargs: Dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"name": name,
|
"name": name,
|
||||||
# Until we get group management, we just put everyone in the
|
# Until we get group management, we just put everyone in the
|
||||||
# same group.
|
# same group.
|
||||||
|
@ -146,9 +148,9 @@ class AuthStore:
|
||||||
async def async_update_user(
|
async def async_update_user(
|
||||||
self,
|
self,
|
||||||
user: models.User,
|
user: models.User,
|
||||||
name: Optional[str] = None,
|
name: str | None = None,
|
||||||
is_active: Optional[bool] = None,
|
is_active: bool | None = None,
|
||||||
group_ids: Optional[List[str]] = None,
|
group_ids: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a user."""
|
"""Update a user."""
|
||||||
assert self._groups is not None
|
assert self._groups is not None
|
||||||
|
@ -203,15 +205,15 @@ class AuthStore:
|
||||||
async def async_create_refresh_token(
|
async def async_create_refresh_token(
|
||||||
self,
|
self,
|
||||||
user: models.User,
|
user: models.User,
|
||||||
client_id: Optional[str] = None,
|
client_id: str | None = None,
|
||||||
client_name: Optional[str] = None,
|
client_name: str | None = None,
|
||||||
client_icon: Optional[str] = None,
|
client_icon: str | None = None,
|
||||||
token_type: str = models.TOKEN_TYPE_NORMAL,
|
token_type: str = models.TOKEN_TYPE_NORMAL,
|
||||||
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
||||||
credential: Optional[models.Credentials] = None,
|
credential: models.Credentials | None = None,
|
||||||
) -> models.RefreshToken:
|
) -> models.RefreshToken:
|
||||||
"""Create a new token for a user."""
|
"""Create a new token for a user."""
|
||||||
kwargs: Dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"user": user,
|
"user": user,
|
||||||
"client_id": client_id,
|
"client_id": client_id,
|
||||||
"token_type": token_type,
|
"token_type": token_type,
|
||||||
|
@ -244,7 +246,7 @@ class AuthStore:
|
||||||
|
|
||||||
async def async_get_refresh_token(
|
async def async_get_refresh_token(
|
||||||
self, token_id: str
|
self, token_id: str
|
||||||
) -> Optional[models.RefreshToken]:
|
) -> models.RefreshToken | None:
|
||||||
"""Get refresh token by id."""
|
"""Get refresh token by id."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self._async_load()
|
await self._async_load()
|
||||||
|
@ -259,7 +261,7 @@ class AuthStore:
|
||||||
|
|
||||||
async def async_get_refresh_token_by_token(
|
async def async_get_refresh_token_by_token(
|
||||||
self, token: str
|
self, token: str
|
||||||
) -> Optional[models.RefreshToken]:
|
) -> models.RefreshToken | None:
|
||||||
"""Get refresh token by token."""
|
"""Get refresh token by token."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self._async_load()
|
await self._async_load()
|
||||||
|
@ -276,7 +278,7 @@ class AuthStore:
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_log_refresh_token_usage(
|
def async_log_refresh_token_usage(
|
||||||
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
|
self, refresh_token: models.RefreshToken, remote_ip: str | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update refresh token last used information."""
|
"""Update refresh token last used information."""
|
||||||
refresh_token.last_used_at = dt_util.utcnow()
|
refresh_token.last_used_at = dt_util.utcnow()
|
||||||
|
@ -309,9 +311,9 @@ class AuthStore:
|
||||||
self._set_defaults()
|
self._set_defaults()
|
||||||
return
|
return
|
||||||
|
|
||||||
users: Dict[str, models.User] = OrderedDict()
|
users: dict[str, models.User] = OrderedDict()
|
||||||
groups: Dict[str, models.Group] = OrderedDict()
|
groups: dict[str, models.Group] = OrderedDict()
|
||||||
credentials: Dict[str, models.Credentials] = OrderedDict()
|
credentials: dict[str, models.Credentials] = OrderedDict()
|
||||||
|
|
||||||
# Soft-migrating data as we load. We are going to make sure we have a
|
# Soft-migrating data as we load. We are going to make sure we have a
|
||||||
# read only group and an admin group. There are two states that we can
|
# read only group and an admin group. There are two states that we can
|
||||||
|
@ -328,7 +330,7 @@ class AuthStore:
|
||||||
# was added.
|
# was added.
|
||||||
|
|
||||||
for group_dict in data.get("groups", []):
|
for group_dict in data.get("groups", []):
|
||||||
policy: Optional[PolicyType] = None
|
policy: PolicyType | None = None
|
||||||
|
|
||||||
if group_dict["id"] == GROUP_ID_ADMIN:
|
if group_dict["id"] == GROUP_ID_ADMIN:
|
||||||
has_admin_group = True
|
has_admin_group = True
|
||||||
|
@ -489,7 +491,7 @@ class AuthStore:
|
||||||
self._store.async_delay_save(self._data_to_save, 1)
|
self._store.async_delay_save(self._data_to_save, 1)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _data_to_save(self) -> Dict:
|
def _data_to_save(self) -> dict:
|
||||||
"""Return the data to store."""
|
"""Return the data to store."""
|
||||||
assert self._users is not None
|
assert self._users is not None
|
||||||
assert self._groups is not None
|
assert self._groups is not None
|
||||||
|
@ -508,7 +510,7 @@ class AuthStore:
|
||||||
|
|
||||||
groups = []
|
groups = []
|
||||||
for group in self._groups.values():
|
for group in self._groups.values():
|
||||||
g_dict: Dict[str, Any] = {
|
g_dict: dict[str, Any] = {
|
||||||
"id": group.id,
|
"id": group.id,
|
||||||
# Name not read for sys groups. Kept here for backwards compat
|
# Name not read for sys groups. Kept here for backwards compat
|
||||||
"name": group.name,
|
"name": group.name,
|
||||||
|
@ -567,7 +569,7 @@ class AuthStore:
|
||||||
"""Set default values for auth store."""
|
"""Set default values for auth store."""
|
||||||
self._users = OrderedDict()
|
self._users = OrderedDict()
|
||||||
|
|
||||||
groups: Dict[str, models.Group] = OrderedDict()
|
groups: dict[str, models.Group] = OrderedDict()
|
||||||
admin_group = _system_admin_group()
|
admin_group = _system_admin_group()
|
||||||
groups[admin_group.id] = admin_group
|
groups[admin_group.id] = admin_group
|
||||||
user_group = _system_user_group()
|
user_group = _system_user_group()
|
||||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import types
|
import types
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous.humanize import humanize_error
|
from voluptuous.humanize import humanize_error
|
||||||
|
@ -38,7 +38,7 @@ class MultiFactorAuthModule:
|
||||||
DEFAULT_TITLE = "Unnamed auth module"
|
DEFAULT_TITLE = "Unnamed auth module"
|
||||||
MAX_RETRY_TIME = 3
|
MAX_RETRY_TIME = 3
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
|
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||||
"""Initialize an auth module."""
|
"""Initialize an auth module."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -87,7 +87,7 @@ class MultiFactorAuthModule:
|
||||||
"""Return whether user is setup."""
|
"""Return whether user is setup."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
|
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
|
||||||
"""Return True if validation passed."""
|
"""Return True if validation passed."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -104,14 +104,14 @@ class SetupFlow(data_entry_flow.FlowHandler):
|
||||||
self._user_id = user_id
|
self._user_id = user_id
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Handle the first step of setup flow.
|
"""Handle the first step of setup flow.
|
||||||
|
|
||||||
Return self.async_show_form(step_id='init') if user_input is None.
|
Return self.async_show_form(step_id='init') if user_input is None.
|
||||||
Return self.async_create_entry(data={'result': result}) if finish.
|
Return self.async_create_entry(data={'result': result}) if finish.
|
||||||
"""
|
"""
|
||||||
errors: Dict[str, str] = {}
|
errors: dict[str, str] = {}
|
||||||
|
|
||||||
if user_input:
|
if user_input:
|
||||||
result = await self._auth_module.async_setup_user(self._user_id, user_input)
|
result = await self._auth_module.async_setup_user(self._user_id, user_input)
|
||||||
|
@ -125,7 +125,7 @@ class SetupFlow(data_entry_flow.FlowHandler):
|
||||||
|
|
||||||
|
|
||||||
async def auth_mfa_module_from_config(
|
async def auth_mfa_module_from_config(
|
||||||
hass: HomeAssistant, config: Dict[str, Any]
|
hass: HomeAssistant, config: dict[str, Any]
|
||||||
) -> MultiFactorAuthModule:
|
) -> MultiFactorAuthModule:
|
||||||
"""Initialize an auth module from a config."""
|
"""Initialize an auth module from a config."""
|
||||||
module_name = config[CONF_TYPE]
|
module_name = config[CONF_TYPE]
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
"""Example auth module."""
|
"""Example auth module."""
|
||||||
from typing import Any, Dict
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -28,7 +30,7 @@ class InsecureExampleModule(MultiFactorAuthModule):
|
||||||
|
|
||||||
DEFAULT_TITLE = "Insecure Personal Identify Number"
|
DEFAULT_TITLE = "Insecure Personal Identify Number"
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
|
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||||
"""Initialize the user data store."""
|
"""Initialize the user data store."""
|
||||||
super().__init__(hass, config)
|
super().__init__(hass, config)
|
||||||
self._data = config["data"]
|
self._data = config["data"]
|
||||||
|
@ -80,7 +82,7 @@ class InsecureExampleModule(MultiFactorAuthModule):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
|
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
|
||||||
"""Return True if validation passed."""
|
"""Return True if validation passed."""
|
||||||
for data in self._data:
|
for data in self._data:
|
||||||
if data["user_id"] == user_id:
|
if data["user_id"] == user_id:
|
||||||
|
|
|
@ -2,10 +2,12 @@
|
||||||
|
|
||||||
Sending HOTP through notify service
|
Sending HOTP through notify service
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -79,8 +81,8 @@ class NotifySetting:
|
||||||
|
|
||||||
secret: str = attr.ib(factory=_generate_secret) # not persistent
|
secret: str = attr.ib(factory=_generate_secret) # not persistent
|
||||||
counter: int = attr.ib(factory=_generate_random) # not persistent
|
counter: int = attr.ib(factory=_generate_random) # not persistent
|
||||||
notify_service: Optional[str] = attr.ib(default=None)
|
notify_service: str | None = attr.ib(default=None)
|
||||||
target: Optional[str] = attr.ib(default=None)
|
target: str | None = attr.ib(default=None)
|
||||||
|
|
||||||
|
|
||||||
_UsersDict = Dict[str, NotifySetting]
|
_UsersDict = Dict[str, NotifySetting]
|
||||||
|
@ -92,10 +94,10 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
||||||
|
|
||||||
DEFAULT_TITLE = "Notify One-Time Password"
|
DEFAULT_TITLE = "Notify One-Time Password"
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
|
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||||
"""Initialize the user data store."""
|
"""Initialize the user data store."""
|
||||||
super().__init__(hass, config)
|
super().__init__(hass, config)
|
||||||
self._user_settings: Optional[_UsersDict] = None
|
self._user_settings: _UsersDict | None = None
|
||||||
self._user_store = hass.helpers.storage.Store(
|
self._user_store = hass.helpers.storage.Store(
|
||||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
STORAGE_VERSION, STORAGE_KEY, private=True
|
||||||
)
|
)
|
||||||
|
@ -146,7 +148,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
||||||
)
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def aync_get_available_notify_services(self) -> List[str]:
|
def aync_get_available_notify_services(self) -> list[str]:
|
||||||
"""Return list of notify services."""
|
"""Return list of notify services."""
|
||||||
unordered_services = set()
|
unordered_services = set()
|
||||||
|
|
||||||
|
@ -198,7 +200,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
||||||
|
|
||||||
return user_id in self._user_settings
|
return user_id in self._user_settings
|
||||||
|
|
||||||
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
|
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
|
||||||
"""Return True if validation passed."""
|
"""Return True if validation passed."""
|
||||||
if self._user_settings is None:
|
if self._user_settings is None:
|
||||||
await self._async_load()
|
await self._async_load()
|
||||||
|
@ -258,7 +260,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_notify(
|
async def async_notify(
|
||||||
self, code: str, notify_service: str, target: Optional[str] = None
|
self, code: str, notify_service: str, target: str | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send code by notify service."""
|
"""Send code by notify service."""
|
||||||
data = {"message": self._message_template.format(code)}
|
data = {"message": self._message_template.format(code)}
|
||||||
|
@ -276,23 +278,23 @@ class NotifySetupFlow(SetupFlow):
|
||||||
auth_module: NotifyAuthModule,
|
auth_module: NotifyAuthModule,
|
||||||
setup_schema: vol.Schema,
|
setup_schema: vol.Schema,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
available_notify_services: List[str],
|
available_notify_services: list[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the setup flow."""
|
"""Initialize the setup flow."""
|
||||||
super().__init__(auth_module, setup_schema, user_id)
|
super().__init__(auth_module, setup_schema, user_id)
|
||||||
# to fix typing complaint
|
# to fix typing complaint
|
||||||
self._auth_module: NotifyAuthModule = auth_module
|
self._auth_module: NotifyAuthModule = auth_module
|
||||||
self._available_notify_services = available_notify_services
|
self._available_notify_services = available_notify_services
|
||||||
self._secret: Optional[str] = None
|
self._secret: str | None = None
|
||||||
self._count: Optional[int] = None
|
self._count: int | None = None
|
||||||
self._notify_service: Optional[str] = None
|
self._notify_service: str | None = None
|
||||||
self._target: Optional[str] = None
|
self._target: str | None = None
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Let user select available notify services."""
|
"""Let user select available notify services."""
|
||||||
errors: Dict[str, str] = {}
|
errors: dict[str, str] = {}
|
||||||
|
|
||||||
hass = self._auth_module.hass
|
hass = self._auth_module.hass
|
||||||
if user_input:
|
if user_input:
|
||||||
|
@ -306,7 +308,7 @@ class NotifySetupFlow(SetupFlow):
|
||||||
if not self._available_notify_services:
|
if not self._available_notify_services:
|
||||||
return self.async_abort(reason="no_available_service")
|
return self.async_abort(reason="no_available_service")
|
||||||
|
|
||||||
schema: Dict[str, Any] = OrderedDict()
|
schema: dict[str, Any] = OrderedDict()
|
||||||
schema["notify_service"] = vol.In(self._available_notify_services)
|
schema["notify_service"] = vol.In(self._available_notify_services)
|
||||||
schema["target"] = vol.Optional(str)
|
schema["target"] = vol.Optional(str)
|
||||||
|
|
||||||
|
@ -315,10 +317,10 @@ class NotifySetupFlow(SetupFlow):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_step_setup(
|
async def async_step_setup(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Verify user can receive one-time password."""
|
"""Verify user can receive one-time password."""
|
||||||
errors: Dict[str, str] = {}
|
errors: dict[str, str] = {}
|
||||||
|
|
||||||
hass = self._auth_module.hass
|
hass = self._auth_module.hass
|
||||||
if user_input:
|
if user_input:
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
"""Time-based One Time Password auth module."""
|
"""Time-based One Time Password auth module."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -50,7 +52,7 @@ def _generate_qr_code(data: str) -> str:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _generate_secret_and_qr_code(username: str) -> Tuple[str, str, str]:
|
def _generate_secret_and_qr_code(username: str) -> tuple[str, str, str]:
|
||||||
"""Generate a secret, url, and QR code."""
|
"""Generate a secret, url, and QR code."""
|
||||||
import pyotp # pylint: disable=import-outside-toplevel
|
import pyotp # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
@ -69,10 +71,10 @@ class TotpAuthModule(MultiFactorAuthModule):
|
||||||
DEFAULT_TITLE = "Time-based One Time Password"
|
DEFAULT_TITLE = "Time-based One Time Password"
|
||||||
MAX_RETRY_TIME = 5
|
MAX_RETRY_TIME = 5
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
|
def __init__(self, hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||||
"""Initialize the user data store."""
|
"""Initialize the user data store."""
|
||||||
super().__init__(hass, config)
|
super().__init__(hass, config)
|
||||||
self._users: Optional[Dict[str, str]] = None
|
self._users: dict[str, str] | None = None
|
||||||
self._user_store = hass.helpers.storage.Store(
|
self._user_store = hass.helpers.storage.Store(
|
||||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
STORAGE_VERSION, STORAGE_KEY, private=True
|
||||||
)
|
)
|
||||||
|
@ -100,7 +102,7 @@ class TotpAuthModule(MultiFactorAuthModule):
|
||||||
"""Save data."""
|
"""Save data."""
|
||||||
await self._user_store.async_save({STORAGE_USERS: self._users})
|
await self._user_store.async_save({STORAGE_USERS: self._users})
|
||||||
|
|
||||||
def _add_ota_secret(self, user_id: str, secret: Optional[str] = None) -> str:
|
def _add_ota_secret(self, user_id: str, secret: str | None = None) -> str:
|
||||||
"""Create a ota_secret for user."""
|
"""Create a ota_secret for user."""
|
||||||
import pyotp # pylint: disable=import-outside-toplevel
|
import pyotp # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
@ -145,7 +147,7 @@ class TotpAuthModule(MultiFactorAuthModule):
|
||||||
|
|
||||||
return user_id in self._users # type: ignore
|
return user_id in self._users # type: ignore
|
||||||
|
|
||||||
async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
|
async def async_validate(self, user_id: str, user_input: dict[str, Any]) -> bool:
|
||||||
"""Return True if validation passed."""
|
"""Return True if validation passed."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self._async_load()
|
await self._async_load()
|
||||||
|
@ -181,13 +183,13 @@ class TotpSetupFlow(SetupFlow):
|
||||||
# to fix typing complaint
|
# to fix typing complaint
|
||||||
self._auth_module: TotpAuthModule = auth_module
|
self._auth_module: TotpAuthModule = auth_module
|
||||||
self._user = user
|
self._user = user
|
||||||
self._ota_secret: Optional[str] = None
|
self._ota_secret: str | None = None
|
||||||
self._url = None # type Optional[str]
|
self._url = None # type Optional[str]
|
||||||
self._image = None # type Optional[str]
|
self._image = None # type Optional[str]
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Handle the first step of setup flow.
|
"""Handle the first step of setup flow.
|
||||||
|
|
||||||
Return self.async_show_form(step_id='init') if user_input is None.
|
Return self.async_show_form(step_id='init') if user_input is None.
|
||||||
|
@ -195,7 +197,7 @@ class TotpSetupFlow(SetupFlow):
|
||||||
"""
|
"""
|
||||||
import pyotp # pylint: disable=import-outside-toplevel
|
import pyotp # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
errors: Dict[str, str] = {}
|
errors: dict[str, str] = {}
|
||||||
|
|
||||||
if user_input:
|
if user_input:
|
||||||
verified = await self.hass.async_add_executor_job(
|
verified = await self.hass.async_add_executor_job(
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
"""Auth models."""
|
"""Auth models."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Dict, List, NamedTuple, Optional
|
from typing import NamedTuple
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -21,7 +23,7 @@ TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
|
||||||
class Group:
|
class Group:
|
||||||
"""A group."""
|
"""A group."""
|
||||||
|
|
||||||
name: Optional[str] = attr.ib()
|
name: str | None = attr.ib()
|
||||||
policy: perm_mdl.PolicyType = attr.ib()
|
policy: perm_mdl.PolicyType = attr.ib()
|
||||||
id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
|
id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
|
||||||
system_generated: bool = attr.ib(default=False)
|
system_generated: bool = attr.ib(default=False)
|
||||||
|
@ -31,24 +33,24 @@ class Group:
|
||||||
class User:
|
class User:
|
||||||
"""A user."""
|
"""A user."""
|
||||||
|
|
||||||
name: Optional[str] = attr.ib()
|
name: str | None = attr.ib()
|
||||||
perm_lookup: perm_mdl.PermissionLookup = attr.ib(eq=False, order=False)
|
perm_lookup: perm_mdl.PermissionLookup = attr.ib(eq=False, order=False)
|
||||||
id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
|
id: str = attr.ib(factory=lambda: uuid.uuid4().hex)
|
||||||
is_owner: bool = attr.ib(default=False)
|
is_owner: bool = attr.ib(default=False)
|
||||||
is_active: bool = attr.ib(default=False)
|
is_active: bool = attr.ib(default=False)
|
||||||
system_generated: bool = attr.ib(default=False)
|
system_generated: bool = attr.ib(default=False)
|
||||||
|
|
||||||
groups: List[Group] = attr.ib(factory=list, eq=False, order=False)
|
groups: list[Group] = attr.ib(factory=list, eq=False, order=False)
|
||||||
|
|
||||||
# List of credentials of a user.
|
# List of credentials of a user.
|
||||||
credentials: List["Credentials"] = attr.ib(factory=list, eq=False, order=False)
|
credentials: list["Credentials"] = attr.ib(factory=list, eq=False, order=False)
|
||||||
|
|
||||||
# Tokens associated with a user.
|
# Tokens associated with a user.
|
||||||
refresh_tokens: Dict[str, "RefreshToken"] = attr.ib(
|
refresh_tokens: dict[str, "RefreshToken"] = attr.ib(
|
||||||
factory=dict, eq=False, order=False
|
factory=dict, eq=False, order=False
|
||||||
)
|
)
|
||||||
|
|
||||||
_permissions: Optional[perm_mdl.PolicyPermissions] = attr.ib(
|
_permissions: perm_mdl.PolicyPermissions | None = attr.ib(
|
||||||
init=False,
|
init=False,
|
||||||
eq=False,
|
eq=False,
|
||||||
order=False,
|
order=False,
|
||||||
|
@ -89,10 +91,10 @@ class RefreshToken:
|
||||||
"""RefreshToken for a user to grant new access tokens."""
|
"""RefreshToken for a user to grant new access tokens."""
|
||||||
|
|
||||||
user: User = attr.ib()
|
user: User = attr.ib()
|
||||||
client_id: Optional[str] = attr.ib()
|
client_id: str | None = attr.ib()
|
||||||
access_token_expiration: timedelta = attr.ib()
|
access_token_expiration: timedelta = attr.ib()
|
||||||
client_name: Optional[str] = attr.ib(default=None)
|
client_name: str | None = attr.ib(default=None)
|
||||||
client_icon: Optional[str] = attr.ib(default=None)
|
client_icon: str | None = attr.ib(default=None)
|
||||||
token_type: str = attr.ib(
|
token_type: str = attr.ib(
|
||||||
default=TOKEN_TYPE_NORMAL,
|
default=TOKEN_TYPE_NORMAL,
|
||||||
validator=attr.validators.in_(
|
validator=attr.validators.in_(
|
||||||
|
@ -104,12 +106,12 @@ class RefreshToken:
|
||||||
token: str = attr.ib(factory=lambda: secrets.token_hex(64))
|
token: str = attr.ib(factory=lambda: secrets.token_hex(64))
|
||||||
jwt_key: str = attr.ib(factory=lambda: secrets.token_hex(64))
|
jwt_key: str = attr.ib(factory=lambda: secrets.token_hex(64))
|
||||||
|
|
||||||
last_used_at: Optional[datetime] = attr.ib(default=None)
|
last_used_at: datetime | None = attr.ib(default=None)
|
||||||
last_used_ip: Optional[str] = attr.ib(default=None)
|
last_used_ip: str | None = attr.ib(default=None)
|
||||||
|
|
||||||
credential: Optional["Credentials"] = attr.ib(default=None)
|
credential: "Credentials" | None = attr.ib(default=None)
|
||||||
|
|
||||||
version: Optional[str] = attr.ib(default=__version__)
|
version: str | None = attr.ib(default=__version__)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
|
@ -117,7 +119,7 @@ class Credentials:
|
||||||
"""Credentials for a user on an auth provider."""
|
"""Credentials for a user on an auth provider."""
|
||||||
|
|
||||||
auth_provider_type: str = attr.ib()
|
auth_provider_type: str = attr.ib()
|
||||||
auth_provider_id: Optional[str] = attr.ib()
|
auth_provider_id: str | None = attr.ib()
|
||||||
|
|
||||||
# Allow the auth provider to store data to represent their auth.
|
# Allow the auth provider to store data to represent their auth.
|
||||||
data: dict = attr.ib()
|
data: dict = attr.ib()
|
||||||
|
@ -129,5 +131,5 @@ class Credentials:
|
||||||
class UserMeta(NamedTuple):
|
class UserMeta(NamedTuple):
|
||||||
"""User metadata."""
|
"""User metadata."""
|
||||||
|
|
||||||
name: Optional[str]
|
name: str | None
|
||||||
is_active: bool
|
is_active: bool
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
"""Permissions for Home Assistant."""
|
"""Permissions for Home Assistant."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -19,7 +21,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||||
class AbstractPermissions:
|
class AbstractPermissions:
|
||||||
"""Default permissions class."""
|
"""Default permissions class."""
|
||||||
|
|
||||||
_cached_entity_func: Optional[Callable[[str, str], bool]] = None
|
_cached_entity_func: Callable[[str, str], bool] | None = None
|
||||||
|
|
||||||
def _entity_func(self) -> Callable[[str, str], bool]:
|
def _entity_func(self) -> Callable[[str, str], bool]:
|
||||||
"""Return a function that can test entity access."""
|
"""Return a function that can test entity access."""
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
"""Entity permissions."""
|
"""Entity permissions."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Callable, Optional
|
from typing import Callable
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -43,14 +45,14 @@ ENTITY_POLICY_SCHEMA = vol.Any(
|
||||||
|
|
||||||
def _lookup_domain(
|
def _lookup_domain(
|
||||||
perm_lookup: PermissionLookup, domains_dict: SubCategoryDict, entity_id: str
|
perm_lookup: PermissionLookup, domains_dict: SubCategoryDict, entity_id: str
|
||||||
) -> Optional[ValueType]:
|
) -> ValueType | None:
|
||||||
"""Look up entity permissions by domain."""
|
"""Look up entity permissions by domain."""
|
||||||
return domains_dict.get(entity_id.split(".", 1)[0])
|
return domains_dict.get(entity_id.split(".", 1)[0])
|
||||||
|
|
||||||
|
|
||||||
def _lookup_area(
|
def _lookup_area(
|
||||||
perm_lookup: PermissionLookup, area_dict: SubCategoryDict, entity_id: str
|
perm_lookup: PermissionLookup, area_dict: SubCategoryDict, entity_id: str
|
||||||
) -> Optional[ValueType]:
|
) -> ValueType | None:
|
||||||
"""Look up entity permissions by area."""
|
"""Look up entity permissions by area."""
|
||||||
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
|
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
|
||||||
|
|
||||||
|
@ -67,7 +69,7 @@ def _lookup_area(
|
||||||
|
|
||||||
def _lookup_device(
|
def _lookup_device(
|
||||||
perm_lookup: PermissionLookup, devices_dict: SubCategoryDict, entity_id: str
|
perm_lookup: PermissionLookup, devices_dict: SubCategoryDict, entity_id: str
|
||||||
) -> Optional[ValueType]:
|
) -> ValueType | None:
|
||||||
"""Look up entity permissions by device."""
|
"""Look up entity permissions by device."""
|
||||||
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
|
entity_entry = perm_lookup.entity_registry.async_get(entity_id)
|
||||||
|
|
||||||
|
@ -79,7 +81,7 @@ def _lookup_device(
|
||||||
|
|
||||||
def _lookup_entity_id(
|
def _lookup_entity_id(
|
||||||
perm_lookup: PermissionLookup, entities_dict: SubCategoryDict, entity_id: str
|
perm_lookup: PermissionLookup, entities_dict: SubCategoryDict, entity_id: str
|
||||||
) -> Optional[ValueType]:
|
) -> ValueType | None:
|
||||||
"""Look up entity permission by entity id."""
|
"""Look up entity permission by entity id."""
|
||||||
return entities_dict.get(entity_id)
|
return entities_dict.get(entity_id)
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,15 @@
|
||||||
"""Merging of policies."""
|
"""Merging of policies."""
|
||||||
from typing import Dict, List, Set, cast
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from .types import CategoryType, PolicyType
|
from .types import CategoryType, PolicyType
|
||||||
|
|
||||||
|
|
||||||
def merge_policies(policies: List[PolicyType]) -> PolicyType:
|
def merge_policies(policies: list[PolicyType]) -> PolicyType:
|
||||||
"""Merge policies."""
|
"""Merge policies."""
|
||||||
new_policy: Dict[str, CategoryType] = {}
|
new_policy: dict[str, CategoryType] = {}
|
||||||
seen: Set[str] = set()
|
seen: set[str] = set()
|
||||||
for policy in policies:
|
for policy in policies:
|
||||||
for category in policy:
|
for category in policy:
|
||||||
if category in seen:
|
if category in seen:
|
||||||
|
@ -20,7 +22,7 @@ def merge_policies(policies: List[PolicyType]) -> PolicyType:
|
||||||
return new_policy
|
return new_policy
|
||||||
|
|
||||||
|
|
||||||
def _merge_policies(sources: List[CategoryType]) -> CategoryType:
|
def _merge_policies(sources: list[CategoryType]) -> CategoryType:
|
||||||
"""Merge a policy."""
|
"""Merge a policy."""
|
||||||
# When merging policies, the most permissive wins.
|
# When merging policies, the most permissive wins.
|
||||||
# This means we order it like this:
|
# This means we order it like this:
|
||||||
|
@ -34,7 +36,7 @@ def _merge_policies(sources: List[CategoryType]) -> CategoryType:
|
||||||
# merge each key in the source.
|
# merge each key in the source.
|
||||||
|
|
||||||
policy: CategoryType = None
|
policy: CategoryType = None
|
||||||
seen: Set[str] = set()
|
seen: set[str] = set()
|
||||||
for source in sources:
|
for source in sources:
|
||||||
if source is None:
|
if source is None:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
"""Helpers to deal with permissions."""
|
"""Helpers to deal with permissions."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable, Dict, List, Optional, cast
|
from typing import Callable, Dict, Optional, cast
|
||||||
|
|
||||||
from .const import SUBCAT_ALL
|
from .const import SUBCAT_ALL
|
||||||
from .models import PermissionLookup
|
from .models import PermissionLookup
|
||||||
|
@ -45,7 +47,7 @@ def compile_policy(
|
||||||
|
|
||||||
assert isinstance(policy, dict)
|
assert isinstance(policy, dict)
|
||||||
|
|
||||||
funcs: List[Callable[[str, str], Optional[bool]]] = []
|
funcs: list[Callable[[str, str], bool | None]] = []
|
||||||
|
|
||||||
for key, lookup_func in subcategories.items():
|
for key, lookup_func in subcategories.items():
|
||||||
lookup_value = policy.get(key)
|
lookup_value = policy.get(key)
|
||||||
|
@ -80,10 +82,10 @@ def compile_policy(
|
||||||
|
|
||||||
def _gen_dict_test_func(
|
def _gen_dict_test_func(
|
||||||
perm_lookup: PermissionLookup, lookup_func: LookupFunc, lookup_dict: SubCategoryDict
|
perm_lookup: PermissionLookup, lookup_func: LookupFunc, lookup_dict: SubCategoryDict
|
||||||
) -> Callable[[str, str], Optional[bool]]:
|
) -> Callable[[str, str], bool | None]:
|
||||||
"""Generate a lookup function."""
|
"""Generate a lookup function."""
|
||||||
|
|
||||||
def test_value(object_id: str, key: str) -> Optional[bool]:
|
def test_value(object_id: str, key: str) -> bool | None:
|
||||||
"""Test if permission is allowed based on the keys."""
|
"""Test if permission is allowed based on the keys."""
|
||||||
schema: ValueType = lookup_func(perm_lookup, lookup_dict, object_id)
|
schema: ValueType = lookup_func(perm_lookup, lookup_dict, object_id)
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import types
|
import types
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous.humanize import humanize_error
|
from voluptuous.humanize import humanize_error
|
||||||
|
@ -42,7 +42,7 @@ class AuthProvider:
|
||||||
DEFAULT_TITLE = "Unnamed auth provider"
|
DEFAULT_TITLE = "Unnamed auth provider"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
|
self, hass: HomeAssistant, store: AuthStore, config: dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize an auth provider."""
|
"""Initialize an auth provider."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
|
@ -50,7 +50,7 @@ class AuthProvider:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self) -> Optional[str]:
|
def id(self) -> str | None:
|
||||||
"""Return id of the auth provider.
|
"""Return id of the auth provider.
|
||||||
|
|
||||||
Optional, can be None.
|
Optional, can be None.
|
||||||
|
@ -72,7 +72,7 @@ class AuthProvider:
|
||||||
"""Return whether multi-factor auth supported by the auth provider."""
|
"""Return whether multi-factor auth supported by the auth provider."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def async_credentials(self) -> List[Credentials]:
|
async def async_credentials(self) -> list[Credentials]:
|
||||||
"""Return all credentials of this provider."""
|
"""Return all credentials of this provider."""
|
||||||
users = await self.store.async_get_users()
|
users = await self.store.async_get_users()
|
||||||
return [
|
return [
|
||||||
|
@ -86,7 +86,7 @@ class AuthProvider:
|
||||||
]
|
]
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_credentials(self, data: Dict[str, str]) -> Credentials:
|
def async_create_credentials(self, data: dict[str, str]) -> Credentials:
|
||||||
"""Create credentials."""
|
"""Create credentials."""
|
||||||
return Credentials(
|
return Credentials(
|
||||||
auth_provider_type=self.type, auth_provider_id=self.id, data=data
|
auth_provider_type=self.type, auth_provider_id=self.id, data=data
|
||||||
|
@ -94,7 +94,7 @@ class AuthProvider:
|
||||||
|
|
||||||
# Implement by extending class
|
# Implement by extending class
|
||||||
|
|
||||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||||
"""Return the data flow for logging in with auth provider.
|
"""Return the data flow for logging in with auth provider.
|
||||||
|
|
||||||
Auth provider should extend LoginFlow and return an instance.
|
Auth provider should extend LoginFlow and return an instance.
|
||||||
|
@ -102,7 +102,7 @@ class AuthProvider:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def async_get_or_create_credentials(
|
async def async_get_or_create_credentials(
|
||||||
self, flow_result: Dict[str, str]
|
self, flow_result: dict[str, str]
|
||||||
) -> Credentials:
|
) -> Credentials:
|
||||||
"""Get credentials based on the flow result."""
|
"""Get credentials based on the flow result."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -121,7 +121,7 @@ class AuthProvider:
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_validate_refresh_token(
|
def async_validate_refresh_token(
|
||||||
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
|
self, refresh_token: RefreshToken, remote_ip: str | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Verify a refresh token is still valid.
|
"""Verify a refresh token is still valid.
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ class AuthProvider:
|
||||||
|
|
||||||
|
|
||||||
async def auth_provider_from_config(
|
async def auth_provider_from_config(
|
||||||
hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
|
hass: HomeAssistant, store: AuthStore, config: dict[str, Any]
|
||||||
) -> AuthProvider:
|
) -> AuthProvider:
|
||||||
"""Initialize an auth provider from a config."""
|
"""Initialize an auth provider from a config."""
|
||||||
provider_name = config[CONF_TYPE]
|
provider_name = config[CONF_TYPE]
|
||||||
|
@ -188,17 +188,17 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
def __init__(self, auth_provider: AuthProvider) -> None:
|
def __init__(self, auth_provider: AuthProvider) -> None:
|
||||||
"""Initialize the login flow."""
|
"""Initialize the login flow."""
|
||||||
self._auth_provider = auth_provider
|
self._auth_provider = auth_provider
|
||||||
self._auth_module_id: Optional[str] = None
|
self._auth_module_id: str | None = None
|
||||||
self._auth_manager = auth_provider.hass.auth
|
self._auth_manager = auth_provider.hass.auth
|
||||||
self.available_mfa_modules: Dict[str, str] = {}
|
self.available_mfa_modules: dict[str, str] = {}
|
||||||
self.created_at = dt_util.utcnow()
|
self.created_at = dt_util.utcnow()
|
||||||
self.invalid_mfa_times = 0
|
self.invalid_mfa_times = 0
|
||||||
self.user: Optional[User] = None
|
self.user: User | None = None
|
||||||
self.credential: Optional[Credentials] = None
|
self.credential: Credentials | None = None
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Handle the first step of login flow.
|
"""Handle the first step of login flow.
|
||||||
|
|
||||||
Return self.async_show_form(step_id='init') if user_input is None.
|
Return self.async_show_form(step_id='init') if user_input is None.
|
||||||
|
@ -207,8 +207,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def async_step_select_mfa_module(
|
async def async_step_select_mfa_module(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Handle the step of select mfa module."""
|
"""Handle the step of select mfa module."""
|
||||||
errors = {}
|
errors = {}
|
||||||
|
|
||||||
|
@ -232,8 +232,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_step_mfa(
|
async def async_step_mfa(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Handle the step of mfa validation."""
|
"""Handle the step of mfa validation."""
|
||||||
assert self.credential
|
assert self.credential
|
||||||
assert self.user
|
assert self.user
|
||||||
|
@ -273,7 +273,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
if not errors:
|
if not errors:
|
||||||
return await self.async_finish(self.credential)
|
return await self.async_finish(self.credential)
|
||||||
|
|
||||||
description_placeholders: Dict[str, Optional[str]] = {
|
description_placeholders: dict[str, str | None] = {
|
||||||
"mfa_module_name": auth_module.name,
|
"mfa_module_name": auth_module.name,
|
||||||
"mfa_module_id": auth_module.id,
|
"mfa_module_id": auth_module.id,
|
||||||
}
|
}
|
||||||
|
@ -285,6 +285,6 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
||||||
errors=errors,
|
errors=errors,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_finish(self, flow_result: Any) -> Dict:
|
async def async_finish(self, flow_result: Any) -> dict:
|
||||||
"""Handle the pass of login flow."""
|
"""Handle the pass of login flow."""
|
||||||
return self.async_create_entry(title=self._auth_provider.name, data=flow_result)
|
return self.async_create_entry(title=self._auth_provider.name, data=flow_result)
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
"""Auth provider that validates credentials via an external command."""
|
"""Auth provider that validates credentials via an external command."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio.subprocess
|
import asyncio.subprocess
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Optional, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -51,9 +52,9 @@ class CommandLineAuthProvider(AuthProvider):
|
||||||
attributes provided by external programs.
|
attributes provided by external programs.
|
||||||
"""
|
"""
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._user_meta: Dict[str, Dict[str, Any]] = {}
|
self._user_meta: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
async def async_login_flow(self, context: Optional[dict]) -> LoginFlow:
|
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return CommandLineLoginFlow(self)
|
return CommandLineLoginFlow(self)
|
||||||
|
|
||||||
|
@ -82,7 +83,7 @@ class CommandLineAuthProvider(AuthProvider):
|
||||||
raise InvalidAuthError
|
raise InvalidAuthError
|
||||||
|
|
||||||
if self.config[CONF_META]:
|
if self.config[CONF_META]:
|
||||||
meta: Dict[str, str] = {}
|
meta: dict[str, str] = {}
|
||||||
for _line in stdout.splitlines():
|
for _line in stdout.splitlines():
|
||||||
try:
|
try:
|
||||||
line = _line.decode().lstrip()
|
line = _line.decode().lstrip()
|
||||||
|
@ -99,7 +100,7 @@ class CommandLineAuthProvider(AuthProvider):
|
||||||
self._user_meta[username] = meta
|
self._user_meta[username] = meta
|
||||||
|
|
||||||
async def async_get_or_create_credentials(
|
async def async_get_or_create_credentials(
|
||||||
self, flow_result: Dict[str, str]
|
self, flow_result: dict[str, str]
|
||||||
) -> Credentials:
|
) -> Credentials:
|
||||||
"""Get credentials based on the flow result."""
|
"""Get credentials based on the flow result."""
|
||||||
username = flow_result["username"]
|
username = flow_result["username"]
|
||||||
|
@ -125,8 +126,8 @@ class CommandLineLoginFlow(LoginFlow):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
errors = {}
|
errors = {}
|
||||||
|
|
||||||
|
@ -143,7 +144,7 @@ class CommandLineLoginFlow(LoginFlow):
|
||||||
user_input.pop("password")
|
user_input.pop("password")
|
||||||
return await self.async_finish(user_input)
|
return await self.async_finish(user_input)
|
||||||
|
|
||||||
schema: Dict[str, type] = collections.OrderedDict()
|
schema: dict[str, type] = collections.OrderedDict()
|
||||||
schema["username"] = str
|
schema["username"] = str
|
||||||
schema["password"] = str
|
schema["password"] = str
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ import asyncio
|
||||||
import base64
|
import base64
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Set, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -21,7 +21,7 @@ STORAGE_VERSION = 1
|
||||||
STORAGE_KEY = "auth_provider.homeassistant"
|
STORAGE_KEY = "auth_provider.homeassistant"
|
||||||
|
|
||||||
|
|
||||||
def _disallow_id(conf: Dict[str, Any]) -> Dict[str, Any]:
|
def _disallow_id(conf: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Disallow ID in config."""
|
"""Disallow ID in config."""
|
||||||
if CONF_ID in conf:
|
if CONF_ID in conf:
|
||||||
raise vol.Invalid("ID is not allowed for the homeassistant auth provider.")
|
raise vol.Invalid("ID is not allowed for the homeassistant auth provider.")
|
||||||
|
@ -62,7 +62,7 @@ class Data:
|
||||||
self._store = hass.helpers.storage.Store(
|
self._store = hass.helpers.storage.Store(
|
||||||
STORAGE_VERSION, STORAGE_KEY, private=True
|
STORAGE_VERSION, STORAGE_KEY, private=True
|
||||||
)
|
)
|
||||||
self._data: Optional[Dict[str, Any]] = None
|
self._data: dict[str, Any] | None = None
|
||||||
# Legacy mode will allow usernames to start/end with whitespace
|
# Legacy mode will allow usernames to start/end with whitespace
|
||||||
# and will compare usernames case-insensitive.
|
# and will compare usernames case-insensitive.
|
||||||
# Remove in 2020 or when we launch 1.0.
|
# Remove in 2020 or when we launch 1.0.
|
||||||
|
@ -83,7 +83,7 @@ class Data:
|
||||||
if data is None:
|
if data is None:
|
||||||
data = {"users": []}
|
data = {"users": []}
|
||||||
|
|
||||||
seen: Set[str] = set()
|
seen: set[str] = set()
|
||||||
|
|
||||||
for user in data["users"]:
|
for user in data["users"]:
|
||||||
username = user["username"]
|
username = user["username"]
|
||||||
|
@ -121,7 +121,7 @@ class Data:
|
||||||
self._data = data
|
self._data = data
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def users(self) -> List[Dict[str, str]]:
|
def users(self) -> list[dict[str, str]]:
|
||||||
"""Return users."""
|
"""Return users."""
|
||||||
return self._data["users"] # type: ignore
|
return self._data["users"] # type: ignore
|
||||||
|
|
||||||
|
@ -220,7 +220,7 @@ class HassAuthProvider(AuthProvider):
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
"""Initialize an Home Assistant auth provider."""
|
"""Initialize an Home Assistant auth provider."""
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.data: Optional[Data] = None
|
self.data: Data | None = None
|
||||||
self._init_lock = asyncio.Lock()
|
self._init_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def async_initialize(self) -> None:
|
async def async_initialize(self) -> None:
|
||||||
|
@ -233,7 +233,7 @@ class HassAuthProvider(AuthProvider):
|
||||||
await data.async_load()
|
await data.async_load()
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
||||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return HassLoginFlow(self)
|
return HassLoginFlow(self)
|
||||||
|
|
||||||
|
@ -277,7 +277,7 @@ class HassAuthProvider(AuthProvider):
|
||||||
await self.data.async_save()
|
await self.data.async_save()
|
||||||
|
|
||||||
async def async_get_or_create_credentials(
|
async def async_get_or_create_credentials(
|
||||||
self, flow_result: Dict[str, str]
|
self, flow_result: dict[str, str]
|
||||||
) -> Credentials:
|
) -> Credentials:
|
||||||
"""Get credentials based on the flow result."""
|
"""Get credentials based on the flow result."""
|
||||||
if self.data is None:
|
if self.data is None:
|
||||||
|
@ -318,8 +318,8 @@ class HassLoginFlow(LoginFlow):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
errors = {}
|
errors = {}
|
||||||
|
|
||||||
|
@ -335,7 +335,7 @@ class HassLoginFlow(LoginFlow):
|
||||||
user_input.pop("password")
|
user_input.pop("password")
|
||||||
return await self.async_finish(user_input)
|
return await self.async_finish(user_input)
|
||||||
|
|
||||||
schema: Dict[str, type] = OrderedDict()
|
schema: dict[str, type] = OrderedDict()
|
||||||
schema["username"] = str
|
schema["username"] = str
|
||||||
schema["password"] = str
|
schema["password"] = str
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
"""Example auth provider."""
|
"""Example auth provider."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import hmac
|
import hmac
|
||||||
from typing import Any, Dict, Optional, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -33,7 +35,7 @@ class InvalidAuthError(HomeAssistantError):
|
||||||
class ExampleAuthProvider(AuthProvider):
|
class ExampleAuthProvider(AuthProvider):
|
||||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||||
|
|
||||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return ExampleLoginFlow(self)
|
return ExampleLoginFlow(self)
|
||||||
|
|
||||||
|
@ -60,7 +62,7 @@ class ExampleAuthProvider(AuthProvider):
|
||||||
raise InvalidAuthError
|
raise InvalidAuthError
|
||||||
|
|
||||||
async def async_get_or_create_credentials(
|
async def async_get_or_create_credentials(
|
||||||
self, flow_result: Dict[str, str]
|
self, flow_result: dict[str, str]
|
||||||
) -> Credentials:
|
) -> Credentials:
|
||||||
"""Get credentials based on the flow result."""
|
"""Get credentials based on the flow result."""
|
||||||
username = flow_result["username"]
|
username = flow_result["username"]
|
||||||
|
@ -94,8 +96,8 @@ class ExampleLoginFlow(LoginFlow):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
errors = {}
|
errors = {}
|
||||||
|
|
||||||
|
@ -111,7 +113,7 @@ class ExampleLoginFlow(LoginFlow):
|
||||||
user_input.pop("password")
|
user_input.pop("password")
|
||||||
return await self.async_finish(user_input)
|
return await self.async_finish(user_input)
|
||||||
|
|
||||||
schema: Dict[str, type] = OrderedDict()
|
schema: dict[str, type] = OrderedDict()
|
||||||
schema["username"] = str
|
schema["username"] = str
|
||||||
schema["password"] = str
|
schema["password"] = str
|
||||||
|
|
||||||
|
|
|
@ -3,8 +3,10 @@ Support Legacy API password auth provider.
|
||||||
|
|
||||||
It will be removed when auth system production ready
|
It will be removed when auth system production ready
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import hmac
|
import hmac
|
||||||
from typing import Any, Dict, Optional, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -40,7 +42,7 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
||||||
"""Return api_password."""
|
"""Return api_password."""
|
||||||
return str(self.config[CONF_API_PASSWORD])
|
return str(self.config[CONF_API_PASSWORD])
|
||||||
|
|
||||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return LegacyLoginFlow(self)
|
return LegacyLoginFlow(self)
|
||||||
|
|
||||||
|
@ -55,7 +57,7 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
|
||||||
raise InvalidAuthError
|
raise InvalidAuthError
|
||||||
|
|
||||||
async def async_get_or_create_credentials(
|
async def async_get_or_create_credentials(
|
||||||
self, flow_result: Dict[str, str]
|
self, flow_result: dict[str, str]
|
||||||
) -> Credentials:
|
) -> Credentials:
|
||||||
"""Return credentials for this login."""
|
"""Return credentials for this login."""
|
||||||
credentials = await self.async_credentials()
|
credentials = await self.async_credentials()
|
||||||
|
@ -79,8 +81,8 @@ class LegacyLoginFlow(LoginFlow):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
errors = {}
|
errors = {}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,8 @@
|
||||||
It shows list of users if access from trusted network.
|
It shows list of users if access from trusted network.
|
||||||
Abort login flow if not access from trusted network.
|
Abort login flow if not access from trusted network.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from ipaddress import (
|
from ipaddress import (
|
||||||
IPv4Address,
|
IPv4Address,
|
||||||
IPv4Network,
|
IPv4Network,
|
||||||
|
@ -11,7 +13,7 @@ from ipaddress import (
|
||||||
ip_address,
|
ip_address,
|
||||||
ip_network,
|
ip_network,
|
||||||
)
|
)
|
||||||
from typing import Any, Dict, List, Optional, Union, cast
|
from typing import Any, Dict, List, Union, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
@ -68,12 +70,12 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||||
DEFAULT_TITLE = "Trusted Networks"
|
DEFAULT_TITLE = "Trusted Networks"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def trusted_networks(self) -> List[IPNetwork]:
|
def trusted_networks(self) -> list[IPNetwork]:
|
||||||
"""Return trusted networks."""
|
"""Return trusted networks."""
|
||||||
return cast(List[IPNetwork], self.config[CONF_TRUSTED_NETWORKS])
|
return cast(List[IPNetwork], self.config[CONF_TRUSTED_NETWORKS])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def trusted_users(self) -> Dict[IPNetwork, Any]:
|
def trusted_users(self) -> dict[IPNetwork, Any]:
|
||||||
"""Return trusted users per network."""
|
"""Return trusted users per network."""
|
||||||
return cast(Dict[IPNetwork, Any], self.config[CONF_TRUSTED_USERS])
|
return cast(Dict[IPNetwork, Any], self.config[CONF_TRUSTED_USERS])
|
||||||
|
|
||||||
|
@ -82,7 +84,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||||
"""Trusted Networks auth provider does not support MFA."""
|
"""Trusted Networks auth provider does not support MFA."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
async def async_login_flow(self, context: dict | None) -> LoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
assert context is not None
|
assert context is not None
|
||||||
ip_addr = cast(IPAddress, context.get("ip_address"))
|
ip_addr = cast(IPAddress, context.get("ip_address"))
|
||||||
|
@ -125,7 +127,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_get_or_create_credentials(
|
async def async_get_or_create_credentials(
|
||||||
self, flow_result: Dict[str, str]
|
self, flow_result: dict[str, str]
|
||||||
) -> Credentials:
|
) -> Credentials:
|
||||||
"""Get credentials based on the flow result."""
|
"""Get credentials based on the flow result."""
|
||||||
user_id = flow_result["user"]
|
user_id = flow_result["user"]
|
||||||
|
@ -169,7 +171,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_validate_refresh_token(
|
def async_validate_refresh_token(
|
||||||
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
|
self, refresh_token: RefreshToken, remote_ip: str | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Verify a refresh token is still valid."""
|
"""Verify a refresh token is still valid."""
|
||||||
if remote_ip is None:
|
if remote_ip is None:
|
||||||
|
@ -186,7 +188,7 @@ class TrustedNetworksLoginFlow(LoginFlow):
|
||||||
self,
|
self,
|
||||||
auth_provider: TrustedNetworksAuthProvider,
|
auth_provider: TrustedNetworksAuthProvider,
|
||||||
ip_addr: IPAddress,
|
ip_addr: IPAddress,
|
||||||
available_users: Dict[str, Optional[str]],
|
available_users: dict[str, str | None],
|
||||||
allow_bypass_login: bool,
|
allow_bypass_login: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the login flow."""
|
"""Initialize the login flow."""
|
||||||
|
@ -196,8 +198,8 @@ class TrustedNetworksLoginFlow(LoginFlow):
|
||||||
self._allow_bypass_login = allow_bypass_login
|
self._allow_bypass_login = allow_bypass_login
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
try:
|
try:
|
||||||
cast(
|
cast(
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
"""Home Assistant command line scripts."""
|
"""Home Assistant command line scripts."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import List, Optional, Sequence, Text
|
from typing import Sequence
|
||||||
|
|
||||||
from homeassistant import runner
|
from homeassistant import runner
|
||||||
from homeassistant.bootstrap import async_mount_local_lib_path
|
from homeassistant.bootstrap import async_mount_local_lib_path
|
||||||
|
@ -16,7 +18,7 @@ from homeassistant.util.package import install_package, is_installed, is_virtual
|
||||||
# mypy: allow-untyped-defs, no-warn-return-any
|
# mypy: allow-untyped-defs, no-warn-return-any
|
||||||
|
|
||||||
|
|
||||||
def run(args: List) -> int:
|
def run(args: list) -> int:
|
||||||
"""Run a script."""
|
"""Run a script."""
|
||||||
scripts = []
|
scripts = []
|
||||||
path = os.path.dirname(__file__)
|
path = os.path.dirname(__file__)
|
||||||
|
@ -65,7 +67,7 @@ def run(args: List) -> int:
|
||||||
return script.run(args[1:]) # type: ignore
|
return script.run(args[1:]) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def extract_config_dir(args: Optional[Sequence[Text]] = None) -> str:
|
def extract_config_dir(args: Sequence[str] | None = None) -> str:
|
||||||
"""Extract the config dir from the arguments or get the default."""
|
"""Extract the config dir from the arguments or get the default."""
|
||||||
parser = argparse.ArgumentParser(add_help=False)
|
parser = argparse.ArgumentParser(add_help=False)
|
||||||
parser.add_argument("-c", "--config", default=None)
|
parser.add_argument("-c", "--config", default=None)
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
"""Script to run benchmarks."""
|
"""Script to run benchmarks."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
|
@ -7,7 +9,7 @@ from datetime import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
from typing import Callable, Dict, TypeVar
|
from typing import Callable, TypeVar
|
||||||
|
|
||||||
from homeassistant import core
|
from homeassistant import core
|
||||||
from homeassistant.components.websocket_api.const import JSON_DUMP
|
from homeassistant.components.websocket_api.const import JSON_DUMP
|
||||||
|
@ -21,7 +23,7 @@ from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name
|
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name
|
||||||
|
|
||||||
BENCHMARKS: Dict[str, Callable] = {}
|
BENCHMARKS: dict[str, Callable] = {}
|
||||||
|
|
||||||
|
|
||||||
def run(args):
|
def run(args):
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
"""Script to check the configuration file."""
|
"""Script to check the configuration file."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
@ -6,7 +8,7 @@ from collections.abc import Mapping, Sequence
|
||||||
from glob import glob
|
from glob import glob
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, List, Tuple
|
from typing import Any, Callable
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from homeassistant import core
|
from homeassistant import core
|
||||||
|
@ -22,13 +24,13 @@ REQUIREMENTS = ("colorlog==4.7.2",)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
MOCKS: Dict[str, Tuple[str, Callable]] = {
|
MOCKS: dict[str, tuple[str, Callable]] = {
|
||||||
"load": ("homeassistant.util.yaml.loader.load_yaml", yaml_loader.load_yaml),
|
"load": ("homeassistant.util.yaml.loader.load_yaml", yaml_loader.load_yaml),
|
||||||
"load*": ("homeassistant.config.load_yaml", yaml_loader.load_yaml),
|
"load*": ("homeassistant.config.load_yaml", yaml_loader.load_yaml),
|
||||||
"secrets": ("homeassistant.util.yaml.loader.secret_yaml", yaml_loader.secret_yaml),
|
"secrets": ("homeassistant.util.yaml.loader.secret_yaml", yaml_loader.secret_yaml),
|
||||||
}
|
}
|
||||||
|
|
||||||
PATCHES: Dict[str, Any] = {}
|
PATCHES: dict[str, Any] = {}
|
||||||
|
|
||||||
C_HEAD = "bold"
|
C_HEAD = "bold"
|
||||||
ERROR_STR = "General Errors"
|
ERROR_STR = "General Errors"
|
||||||
|
@ -48,7 +50,7 @@ def color(the_color, *args, reset=None):
|
||||||
raise ValueError(f"Invalid color {k!s} in {the_color}") from k
|
raise ValueError(f"Invalid color {k!s} in {the_color}") from k
|
||||||
|
|
||||||
|
|
||||||
def run(script_args: List) -> int:
|
def run(script_args: list) -> int:
|
||||||
"""Handle check config commandline script."""
|
"""Handle check config commandline script."""
|
||||||
parser = argparse.ArgumentParser(description="Check Home Assistant configuration.")
|
parser = argparse.ArgumentParser(description="Check Home Assistant configuration.")
|
||||||
parser.add_argument("--script", choices=["check_config"])
|
parser.add_argument("--script", choices=["check_config"])
|
||||||
|
@ -83,7 +85,7 @@ def run(script_args: List) -> int:
|
||||||
|
|
||||||
res = check(config_dir, args.secrets)
|
res = check(config_dir, args.secrets)
|
||||||
|
|
||||||
domain_info: List[str] = []
|
domain_info: list[str] = []
|
||||||
if args.info:
|
if args.info:
|
||||||
domain_info = args.info.split(",")
|
domain_info = args.info.split(",")
|
||||||
|
|
||||||
|
@ -123,7 +125,7 @@ def run(script_args: List) -> int:
|
||||||
dump_dict(res["components"].get(domain))
|
dump_dict(res["components"].get(domain))
|
||||||
|
|
||||||
if args.secrets:
|
if args.secrets:
|
||||||
flatsecret: Dict[str, str] = {}
|
flatsecret: dict[str, str] = {}
|
||||||
|
|
||||||
for sfn, sdict in res["secret_cache"].items():
|
for sfn, sdict in res["secret_cache"].items():
|
||||||
sss = []
|
sss = []
|
||||||
|
@ -149,7 +151,7 @@ def run(script_args: List) -> int:
|
||||||
def check(config_dir, secrets=False):
|
def check(config_dir, secrets=False):
|
||||||
"""Perform a check by mocking hass load functions."""
|
"""Perform a check by mocking hass load functions."""
|
||||||
logging.getLogger("homeassistant.loader").setLevel(logging.CRITICAL)
|
logging.getLogger("homeassistant.loader").setLevel(logging.CRITICAL)
|
||||||
res: Dict[str, Any] = {
|
res: dict[str, Any] = {
|
||||||
"yaml_files": OrderedDict(), # yaml_files loaded
|
"yaml_files": OrderedDict(), # yaml_files loaded
|
||||||
"secrets": OrderedDict(), # secret cache and secrets loaded
|
"secrets": OrderedDict(), # secret cache and secrets loaded
|
||||||
"except": OrderedDict(), # exceptions raised (with config)
|
"except": OrderedDict(), # exceptions raised (with config)
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
"""Helper methods for various modules."""
|
"""Helper methods for various modules."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import enum
|
import enum
|
||||||
|
@ -9,16 +11,7 @@ import socket
|
||||||
import string
|
import string
|
||||||
import threading
|
import threading
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import (
|
from typing import Any, Callable, Coroutine, Iterable, KeysView, TypeVar
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Coroutine,
|
|
||||||
Iterable,
|
|
||||||
KeysView,
|
|
||||||
Optional,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import slugify as unicode_slug
|
import slugify as unicode_slug
|
||||||
|
|
||||||
|
@ -106,8 +99,8 @@ def repr_helper(inp: Any) -> str:
|
||||||
|
|
||||||
|
|
||||||
def convert(
|
def convert(
|
||||||
value: Optional[T], to_type: Callable[[T], U], default: Optional[U] = None
|
value: T | None, to_type: Callable[[T], U], default: U | None = None
|
||||||
) -> Optional[U]:
|
) -> U | None:
|
||||||
"""Convert value to to_type, returns default if fails."""
|
"""Convert value to to_type, returns default if fails."""
|
||||||
try:
|
try:
|
||||||
return default if value is None else to_type(value)
|
return default if value is None else to_type(value)
|
||||||
|
@ -117,7 +110,7 @@ def convert(
|
||||||
|
|
||||||
|
|
||||||
def ensure_unique_string(
|
def ensure_unique_string(
|
||||||
preferred_string: str, current_strings: Union[Iterable[str], KeysView[str]]
|
preferred_string: str, current_strings: Iterable[str] | KeysView[str]
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return a string that is not present in current_strings.
|
"""Return a string that is not present in current_strings.
|
||||||
|
|
||||||
|
@ -213,7 +206,7 @@ class Throttle:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, min_time: timedelta, limit_no_throttle: Optional[timedelta] = None
|
self, min_time: timedelta, limit_no_throttle: timedelta | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the throttle."""
|
"""Initialize the throttle."""
|
||||||
self.min_time = min_time
|
self.min_time = min_time
|
||||||
|
@ -253,7 +246,7 @@ class Throttle:
|
||||||
)
|
)
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]:
|
def wrapper(*args: Any, **kwargs: Any) -> Callable | Coroutine:
|
||||||
"""Wrap that allows wrapped to be called only once per min_time.
|
"""Wrap that allows wrapped to be called only once per min_time.
|
||||||
|
|
||||||
If we cannot acquire the lock, it is running so return None.
|
If we cannot acquire the lock, it is running so return None.
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
"""Utilities to help with aiohttp."""
|
"""Utilities to help with aiohttp."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
from urllib.parse import parse_qsl
|
from urllib.parse import parse_qsl
|
||||||
|
|
||||||
from multidict import CIMultiDict, MultiDict
|
from multidict import CIMultiDict, MultiDict
|
||||||
|
@ -26,7 +28,7 @@ class MockStreamReader:
|
||||||
class MockRequest:
|
class MockRequest:
|
||||||
"""Mock an aiohttp request."""
|
"""Mock an aiohttp request."""
|
||||||
|
|
||||||
mock_source: Optional[str] = None
|
mock_source: str | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -34,8 +36,8 @@ class MockRequest:
|
||||||
mock_source: str,
|
mock_source: str,
|
||||||
method: str = "GET",
|
method: str = "GET",
|
||||||
status: int = HTTP_OK,
|
status: int = HTTP_OK,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: dict[str, str] | None = None,
|
||||||
query_string: Optional[str] = None,
|
query_string: str | None = None,
|
||||||
url: str = "",
|
url: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a request."""
|
"""Initialize a request."""
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
"""Color util methods."""
|
"""Color util methods."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import colorsys
|
import colorsys
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -183,7 +184,7 @@ class GamutType:
|
||||||
blue: XYPoint = attr.ib()
|
blue: XYPoint = attr.ib()
|
||||||
|
|
||||||
|
|
||||||
def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]:
|
def color_name_to_rgb(color_name: str) -> tuple[int, int, int]:
|
||||||
"""Convert color name to RGB hex value."""
|
"""Convert color name to RGB hex value."""
|
||||||
# COLORS map has no spaces in it, so make the color_name have no
|
# COLORS map has no spaces in it, so make the color_name have no
|
||||||
# spaces in it as well for matching purposes
|
# spaces in it as well for matching purposes
|
||||||
|
@ -198,8 +199,8 @@ def color_name_to_rgb(color_name: str) -> Tuple[int, int, int]:
|
||||||
|
|
||||||
|
|
||||||
def color_RGB_to_xy(
|
def color_RGB_to_xy(
|
||||||
iR: int, iG: int, iB: int, Gamut: Optional[GamutType] = None
|
iR: int, iG: int, iB: int, Gamut: GamutType | None = None
|
||||||
) -> Tuple[float, float]:
|
) -> tuple[float, float]:
|
||||||
"""Convert from RGB color to XY color."""
|
"""Convert from RGB color to XY color."""
|
||||||
return color_RGB_to_xy_brightness(iR, iG, iB, Gamut)[:2]
|
return color_RGB_to_xy_brightness(iR, iG, iB, Gamut)[:2]
|
||||||
|
|
||||||
|
@ -208,8 +209,8 @@ def color_RGB_to_xy(
|
||||||
# http://www.developers.meethue.com/documentation/color-conversions-rgb-xy
|
# http://www.developers.meethue.com/documentation/color-conversions-rgb-xy
|
||||||
# License: Code is given as is. Use at your own risk and discretion.
|
# License: Code is given as is. Use at your own risk and discretion.
|
||||||
def color_RGB_to_xy_brightness(
|
def color_RGB_to_xy_brightness(
|
||||||
iR: int, iG: int, iB: int, Gamut: Optional[GamutType] = None
|
iR: int, iG: int, iB: int, Gamut: GamutType | None = None
|
||||||
) -> Tuple[float, float, int]:
|
) -> tuple[float, float, int]:
|
||||||
"""Convert from RGB color to XY color."""
|
"""Convert from RGB color to XY color."""
|
||||||
if iR + iG + iB == 0:
|
if iR + iG + iB == 0:
|
||||||
return 0.0, 0.0, 0
|
return 0.0, 0.0, 0
|
||||||
|
@ -248,8 +249,8 @@ def color_RGB_to_xy_brightness(
|
||||||
|
|
||||||
|
|
||||||
def color_xy_to_RGB(
|
def color_xy_to_RGB(
|
||||||
vX: float, vY: float, Gamut: Optional[GamutType] = None
|
vX: float, vY: float, Gamut: GamutType | None = None
|
||||||
) -> Tuple[int, int, int]:
|
) -> tuple[int, int, int]:
|
||||||
"""Convert from XY to a normalized RGB."""
|
"""Convert from XY to a normalized RGB."""
|
||||||
return color_xy_brightness_to_RGB(vX, vY, 255, Gamut)
|
return color_xy_brightness_to_RGB(vX, vY, 255, Gamut)
|
||||||
|
|
||||||
|
@ -257,8 +258,8 @@ def color_xy_to_RGB(
|
||||||
# Converted to Python from Obj-C, original source from:
|
# Converted to Python from Obj-C, original source from:
|
||||||
# http://www.developers.meethue.com/documentation/color-conversions-rgb-xy
|
# http://www.developers.meethue.com/documentation/color-conversions-rgb-xy
|
||||||
def color_xy_brightness_to_RGB(
|
def color_xy_brightness_to_RGB(
|
||||||
vX: float, vY: float, ibrightness: int, Gamut: Optional[GamutType] = None
|
vX: float, vY: float, ibrightness: int, Gamut: GamutType | None = None
|
||||||
) -> Tuple[int, int, int]:
|
) -> tuple[int, int, int]:
|
||||||
"""Convert from XYZ to RGB."""
|
"""Convert from XYZ to RGB."""
|
||||||
if Gamut:
|
if Gamut:
|
||||||
if not check_point_in_lamps_reach((vX, vY), Gamut):
|
if not check_point_in_lamps_reach((vX, vY), Gamut):
|
||||||
|
@ -304,7 +305,7 @@ def color_xy_brightness_to_RGB(
|
||||||
return (ir, ig, ib)
|
return (ir, ig, ib)
|
||||||
|
|
||||||
|
|
||||||
def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]:
|
def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> tuple[int, int, int]:
|
||||||
"""Convert a hsb into its rgb representation."""
|
"""Convert a hsb into its rgb representation."""
|
||||||
if fS == 0.0:
|
if fS == 0.0:
|
||||||
fV = int(fB * 255)
|
fV = int(fB * 255)
|
||||||
|
@ -345,7 +346,7 @@ def color_hsb_to_RGB(fH: float, fS: float, fB: float) -> Tuple[int, int, int]:
|
||||||
return (r, g, b)
|
return (r, g, b)
|
||||||
|
|
||||||
|
|
||||||
def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> Tuple[float, float, float]:
|
def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> tuple[float, float, float]:
|
||||||
"""Convert an rgb color to its hsv representation.
|
"""Convert an rgb color to its hsv representation.
|
||||||
|
|
||||||
Hue is scaled 0-360
|
Hue is scaled 0-360
|
||||||
|
@ -356,12 +357,12 @@ def color_RGB_to_hsv(iR: float, iG: float, iB: float) -> Tuple[float, float, flo
|
||||||
return round(fHSV[0] * 360, 3), round(fHSV[1] * 100, 3), round(fHSV[2] * 100, 3)
|
return round(fHSV[0] * 360, 3), round(fHSV[1] * 100, 3), round(fHSV[2] * 100, 3)
|
||||||
|
|
||||||
|
|
||||||
def color_RGB_to_hs(iR: float, iG: float, iB: float) -> Tuple[float, float]:
|
def color_RGB_to_hs(iR: float, iG: float, iB: float) -> tuple[float, float]:
|
||||||
"""Convert an rgb color to its hs representation."""
|
"""Convert an rgb color to its hs representation."""
|
||||||
return color_RGB_to_hsv(iR, iG, iB)[:2]
|
return color_RGB_to_hsv(iR, iG, iB)[:2]
|
||||||
|
|
||||||
|
|
||||||
def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> Tuple[int, int, int]:
|
def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> tuple[int, int, int]:
|
||||||
"""Convert an hsv color into its rgb representation.
|
"""Convert an hsv color into its rgb representation.
|
||||||
|
|
||||||
Hue is scaled 0-360
|
Hue is scaled 0-360
|
||||||
|
@ -372,27 +373,27 @@ def color_hsv_to_RGB(iH: float, iS: float, iV: float) -> Tuple[int, int, int]:
|
||||||
return (int(fRGB[0] * 255), int(fRGB[1] * 255), int(fRGB[2] * 255))
|
return (int(fRGB[0] * 255), int(fRGB[1] * 255), int(fRGB[2] * 255))
|
||||||
|
|
||||||
|
|
||||||
def color_hs_to_RGB(iH: float, iS: float) -> Tuple[int, int, int]:
|
def color_hs_to_RGB(iH: float, iS: float) -> tuple[int, int, int]:
|
||||||
"""Convert an hsv color into its rgb representation."""
|
"""Convert an hsv color into its rgb representation."""
|
||||||
return color_hsv_to_RGB(iH, iS, 100)
|
return color_hsv_to_RGB(iH, iS, 100)
|
||||||
|
|
||||||
|
|
||||||
def color_xy_to_hs(
|
def color_xy_to_hs(
|
||||||
vX: float, vY: float, Gamut: Optional[GamutType] = None
|
vX: float, vY: float, Gamut: GamutType | None = None
|
||||||
) -> Tuple[float, float]:
|
) -> tuple[float, float]:
|
||||||
"""Convert an xy color to its hs representation."""
|
"""Convert an xy color to its hs representation."""
|
||||||
h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY, Gamut))
|
h, s, _ = color_RGB_to_hsv(*color_xy_to_RGB(vX, vY, Gamut))
|
||||||
return h, s
|
return h, s
|
||||||
|
|
||||||
|
|
||||||
def color_hs_to_xy(
|
def color_hs_to_xy(
|
||||||
iH: float, iS: float, Gamut: Optional[GamutType] = None
|
iH: float, iS: float, Gamut: GamutType | None = None
|
||||||
) -> Tuple[float, float]:
|
) -> tuple[float, float]:
|
||||||
"""Convert an hs color to its xy representation."""
|
"""Convert an hs color to its xy representation."""
|
||||||
return color_RGB_to_xy(*color_hs_to_RGB(iH, iS), Gamut)
|
return color_RGB_to_xy(*color_hs_to_RGB(iH, iS), Gamut)
|
||||||
|
|
||||||
|
|
||||||
def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple:
|
def _match_max_scale(input_colors: tuple, output_colors: tuple) -> tuple:
|
||||||
"""Match the maximum value of the output to the input."""
|
"""Match the maximum value of the output to the input."""
|
||||||
max_in = max(input_colors)
|
max_in = max(input_colors)
|
||||||
max_out = max(output_colors)
|
max_out = max(output_colors)
|
||||||
|
@ -403,7 +404,7 @@ def _match_max_scale(input_colors: Tuple, output_colors: Tuple) -> Tuple:
|
||||||
return tuple(int(round(i * factor)) for i in output_colors)
|
return tuple(int(round(i * factor)) for i in output_colors)
|
||||||
|
|
||||||
|
|
||||||
def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]:
|
def color_rgb_to_rgbw(r: int, g: int, b: int) -> tuple[int, int, int, int]:
|
||||||
"""Convert an rgb color to an rgbw representation."""
|
"""Convert an rgb color to an rgbw representation."""
|
||||||
# Calculate the white channel as the minimum of input rgb channels.
|
# Calculate the white channel as the minimum of input rgb channels.
|
||||||
# Subtract the white portion from the remaining rgb channels.
|
# Subtract the white portion from the remaining rgb channels.
|
||||||
|
@ -415,7 +416,7 @@ def color_rgb_to_rgbw(r: int, g: int, b: int) -> Tuple[int, int, int, int]:
|
||||||
return _match_max_scale((r, g, b), rgbw) # type: ignore
|
return _match_max_scale((r, g, b), rgbw) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> Tuple[int, int, int]:
|
def color_rgbw_to_rgb(r: int, g: int, b: int, w: int) -> tuple[int, int, int]:
|
||||||
"""Convert an rgbw color to an rgb representation."""
|
"""Convert an rgbw color to an rgb representation."""
|
||||||
# Add the white channel back into the rgb channels.
|
# Add the white channel back into the rgb channels.
|
||||||
rgb = (r + w, g + w, b + w)
|
rgb = (r + w, g + w, b + w)
|
||||||
|
@ -430,7 +431,7 @@ def color_rgb_to_hex(r: int, g: int, b: int) -> str:
|
||||||
return "{:02x}{:02x}{:02x}".format(round(r), round(g), round(b))
|
return "{:02x}{:02x}{:02x}".format(round(r), round(g), round(b))
|
||||||
|
|
||||||
|
|
||||||
def rgb_hex_to_rgb_list(hex_string: str) -> List[int]:
|
def rgb_hex_to_rgb_list(hex_string: str) -> list[int]:
|
||||||
"""Return an RGB color value list from a hex color string."""
|
"""Return an RGB color value list from a hex color string."""
|
||||||
return [
|
return [
|
||||||
int(hex_string[i : i + len(hex_string) // 3], 16)
|
int(hex_string[i : i + len(hex_string) // 3], 16)
|
||||||
|
@ -438,14 +439,14 @@ def rgb_hex_to_rgb_list(hex_string: str) -> List[int]:
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def color_temperature_to_hs(color_temperature_kelvin: float) -> Tuple[float, float]:
|
def color_temperature_to_hs(color_temperature_kelvin: float) -> tuple[float, float]:
|
||||||
"""Return an hs color from a color temperature in Kelvin."""
|
"""Return an hs color from a color temperature in Kelvin."""
|
||||||
return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin))
|
return color_RGB_to_hs(*color_temperature_to_rgb(color_temperature_kelvin))
|
||||||
|
|
||||||
|
|
||||||
def color_temperature_to_rgb(
|
def color_temperature_to_rgb(
|
||||||
color_temperature_kelvin: float,
|
color_temperature_kelvin: float,
|
||||||
) -> Tuple[float, float, float]:
|
) -> tuple[float, float, float]:
|
||||||
"""
|
"""
|
||||||
Return an RGB color from a color temperature in Kelvin.
|
Return an RGB color from a color temperature in Kelvin.
|
||||||
|
|
||||||
|
@ -555,8 +556,8 @@ def get_closest_point_to_line(A: XYPoint, B: XYPoint, P: XYPoint) -> XYPoint:
|
||||||
|
|
||||||
|
|
||||||
def get_closest_point_to_point(
|
def get_closest_point_to_point(
|
||||||
xy_tuple: Tuple[float, float], Gamut: GamutType
|
xy_tuple: tuple[float, float], Gamut: GamutType
|
||||||
) -> Tuple[float, float]:
|
) -> tuple[float, float]:
|
||||||
"""
|
"""
|
||||||
Get the closest matching color within the gamut of the light.
|
Get the closest matching color within the gamut of the light.
|
||||||
|
|
||||||
|
@ -592,7 +593,7 @@ def get_closest_point_to_point(
|
||||||
return (cx, cy)
|
return (cx, cy)
|
||||||
|
|
||||||
|
|
||||||
def check_point_in_lamps_reach(p: Tuple[float, float], Gamut: GamutType) -> bool:
|
def check_point_in_lamps_reach(p: tuple[float, float], Gamut: GamutType) -> bool:
|
||||||
"""Check if the provided XYPoint can be recreated by a Hue lamp."""
|
"""Check if the provided XYPoint can be recreated by a Hue lamp."""
|
||||||
v1 = XYPoint(Gamut.green.x - Gamut.red.x, Gamut.green.y - Gamut.red.y)
|
v1 = XYPoint(Gamut.green.x - Gamut.red.x, Gamut.green.y - Gamut.red.y)
|
||||||
v2 = XYPoint(Gamut.blue.x - Gamut.red.x, Gamut.blue.y - Gamut.red.y)
|
v2 = XYPoint(Gamut.blue.x - Gamut.red.x, Gamut.blue.y - Gamut.red.y)
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
"""Distance util functions."""
|
"""Distance util functions."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from typing import Callable, Dict
|
from typing import Callable
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
LENGTH,
|
LENGTH,
|
||||||
|
@ -26,7 +28,7 @@ VALID_UNITS = [
|
||||||
LENGTH_YARD,
|
LENGTH_YARD,
|
||||||
]
|
]
|
||||||
|
|
||||||
TO_METERS: Dict[str, Callable[[float], float]] = {
|
TO_METERS: dict[str, Callable[[float], float]] = {
|
||||||
LENGTH_METERS: lambda meters: meters,
|
LENGTH_METERS: lambda meters: meters,
|
||||||
LENGTH_MILES: lambda miles: miles * 1609.344,
|
LENGTH_MILES: lambda miles: miles * 1609.344,
|
||||||
LENGTH_YARD: lambda yards: yards * 0.9144,
|
LENGTH_YARD: lambda yards: yards * 0.9144,
|
||||||
|
@ -37,7 +39,7 @@ TO_METERS: Dict[str, Callable[[float], float]] = {
|
||||||
LENGTH_MILLIMETERS: lambda millimeters: millimeters * 0.001,
|
LENGTH_MILLIMETERS: lambda millimeters: millimeters * 0.001,
|
||||||
}
|
}
|
||||||
|
|
||||||
METERS_TO: Dict[str, Callable[[float], float]] = {
|
METERS_TO: dict[str, Callable[[float], float]] = {
|
||||||
LENGTH_METERS: lambda meters: meters,
|
LENGTH_METERS: lambda meters: meters,
|
||||||
LENGTH_MILES: lambda meters: meters * 0.000621371,
|
LENGTH_MILES: lambda meters: meters * 0.000621371,
|
||||||
LENGTH_YARD: lambda meters: meters * 1.09361,
|
LENGTH_YARD: lambda meters: meters * 1.09361,
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
"""Helper methods to handle the time in Home Assistant."""
|
"""Helper methods to handle the time in Home Assistant."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Optional, Union, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import ciso8601
|
import ciso8601
|
||||||
import pytz
|
import pytz
|
||||||
|
@ -40,7 +42,7 @@ def set_default_time_zone(time_zone: dt.tzinfo) -> None:
|
||||||
DEFAULT_TIME_ZONE = time_zone
|
DEFAULT_TIME_ZONE = time_zone
|
||||||
|
|
||||||
|
|
||||||
def get_time_zone(time_zone_str: str) -> Optional[dt.tzinfo]:
|
def get_time_zone(time_zone_str: str) -> dt.tzinfo | None:
|
||||||
"""Get time zone from string. Return None if unable to determine.
|
"""Get time zone from string. Return None if unable to determine.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
|
@ -56,7 +58,7 @@ def utcnow() -> dt.datetime:
|
||||||
return dt.datetime.now(NATIVE_UTC)
|
return dt.datetime.now(NATIVE_UTC)
|
||||||
|
|
||||||
|
|
||||||
def now(time_zone: Optional[dt.tzinfo] = None) -> dt.datetime:
|
def now(time_zone: dt.tzinfo | None = None) -> dt.datetime:
|
||||||
"""Get now in specified time zone."""
|
"""Get now in specified time zone."""
|
||||||
return dt.datetime.now(time_zone or DEFAULT_TIME_ZONE)
|
return dt.datetime.now(time_zone or DEFAULT_TIME_ZONE)
|
||||||
|
|
||||||
|
@ -77,7 +79,7 @@ def as_utc(dattim: dt.datetime) -> dt.datetime:
|
||||||
def as_timestamp(dt_value: dt.datetime) -> float:
|
def as_timestamp(dt_value: dt.datetime) -> float:
|
||||||
"""Convert a date/time into a unix time (seconds since 1970)."""
|
"""Convert a date/time into a unix time (seconds since 1970)."""
|
||||||
if hasattr(dt_value, "timestamp"):
|
if hasattr(dt_value, "timestamp"):
|
||||||
parsed_dt: Optional[dt.datetime] = dt_value
|
parsed_dt: dt.datetime | None = dt_value
|
||||||
else:
|
else:
|
||||||
parsed_dt = parse_datetime(str(dt_value))
|
parsed_dt = parse_datetime(str(dt_value))
|
||||||
if parsed_dt is None:
|
if parsed_dt is None:
|
||||||
|
@ -100,9 +102,7 @@ def utc_from_timestamp(timestamp: float) -> dt.datetime:
|
||||||
return UTC.localize(dt.datetime.utcfromtimestamp(timestamp))
|
return UTC.localize(dt.datetime.utcfromtimestamp(timestamp))
|
||||||
|
|
||||||
|
|
||||||
def start_of_local_day(
|
def start_of_local_day(dt_or_d: dt.date | dt.datetime | None = None) -> dt.datetime:
|
||||||
dt_or_d: Union[dt.date, dt.datetime, None] = None
|
|
||||||
) -> dt.datetime:
|
|
||||||
"""Return local datetime object of start of day from date or datetime."""
|
"""Return local datetime object of start of day from date or datetime."""
|
||||||
if dt_or_d is None:
|
if dt_or_d is None:
|
||||||
date: dt.date = now().date()
|
date: dt.date = now().date()
|
||||||
|
@ -119,7 +119,7 @@ def start_of_local_day(
|
||||||
# Copyright (c) Django Software Foundation and individual contributors.
|
# Copyright (c) Django Software Foundation and individual contributors.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
# https://github.com/django/django/blob/master/LICENSE
|
# https://github.com/django/django/blob/master/LICENSE
|
||||||
def parse_datetime(dt_str: str) -> Optional[dt.datetime]:
|
def parse_datetime(dt_str: str) -> dt.datetime | None:
|
||||||
"""Parse a string and return a datetime.datetime.
|
"""Parse a string and return a datetime.datetime.
|
||||||
|
|
||||||
This function supports time zone offsets. When the input contains one,
|
This function supports time zone offsets. When the input contains one,
|
||||||
|
@ -134,12 +134,12 @@ def parse_datetime(dt_str: str) -> Optional[dt.datetime]:
|
||||||
match = DATETIME_RE.match(dt_str)
|
match = DATETIME_RE.match(dt_str)
|
||||||
if not match:
|
if not match:
|
||||||
return None
|
return None
|
||||||
kws: Dict[str, Any] = match.groupdict()
|
kws: dict[str, Any] = match.groupdict()
|
||||||
if kws["microsecond"]:
|
if kws["microsecond"]:
|
||||||
kws["microsecond"] = kws["microsecond"].ljust(6, "0")
|
kws["microsecond"] = kws["microsecond"].ljust(6, "0")
|
||||||
tzinfo_str = kws.pop("tzinfo")
|
tzinfo_str = kws.pop("tzinfo")
|
||||||
|
|
||||||
tzinfo: Optional[dt.tzinfo] = None
|
tzinfo: dt.tzinfo | None = None
|
||||||
if tzinfo_str == "Z":
|
if tzinfo_str == "Z":
|
||||||
tzinfo = UTC
|
tzinfo = UTC
|
||||||
elif tzinfo_str is not None:
|
elif tzinfo_str is not None:
|
||||||
|
@ -154,7 +154,7 @@ def parse_datetime(dt_str: str) -> Optional[dt.datetime]:
|
||||||
return dt.datetime(**kws)
|
return dt.datetime(**kws)
|
||||||
|
|
||||||
|
|
||||||
def parse_date(dt_str: str) -> Optional[dt.date]:
|
def parse_date(dt_str: str) -> dt.date | None:
|
||||||
"""Convert a date string to a date object."""
|
"""Convert a date string to a date object."""
|
||||||
try:
|
try:
|
||||||
return dt.datetime.strptime(dt_str, DATE_STR_FORMAT).date()
|
return dt.datetime.strptime(dt_str, DATE_STR_FORMAT).date()
|
||||||
|
@ -162,7 +162,7 @@ def parse_date(dt_str: str) -> Optional[dt.date]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def parse_time(time_str: str) -> Optional[dt.time]:
|
def parse_time(time_str: str) -> dt.time | None:
|
||||||
"""Parse a time string (00:20:00) into Time object.
|
"""Parse a time string (00:20:00) into Time object.
|
||||||
|
|
||||||
Return None if invalid.
|
Return None if invalid.
|
||||||
|
@ -213,7 +213,7 @@ def get_age(date: dt.datetime) -> str:
|
||||||
return formatn(rounded_delta, selected_unit)
|
return formatn(rounded_delta, selected_unit)
|
||||||
|
|
||||||
|
|
||||||
def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> List[int]:
|
def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> list[int]:
|
||||||
"""Parse the time expression part and return a list of times to match."""
|
"""Parse the time expression part and return a list of times to match."""
|
||||||
if parameter is None or parameter == MATCH_ALL:
|
if parameter is None or parameter == MATCH_ALL:
|
||||||
res = list(range(min_value, max_value + 1))
|
res = list(range(min_value, max_value + 1))
|
||||||
|
@ -241,9 +241,9 @@ def parse_time_expression(parameter: Any, min_value: int, max_value: int) -> Lis
|
||||||
|
|
||||||
def find_next_time_expression_time(
|
def find_next_time_expression_time(
|
||||||
now: dt.datetime, # pylint: disable=redefined-outer-name
|
now: dt.datetime, # pylint: disable=redefined-outer-name
|
||||||
seconds: List[int],
|
seconds: list[int],
|
||||||
minutes: List[int],
|
minutes: list[int],
|
||||||
hours: List[int],
|
hours: list[int],
|
||||||
) -> dt.datetime:
|
) -> dt.datetime:
|
||||||
"""Find the next datetime from now for which the time expression matches.
|
"""Find the next datetime from now for which the time expression matches.
|
||||||
|
|
||||||
|
@ -257,7 +257,7 @@ def find_next_time_expression_time(
|
||||||
if not seconds or not minutes or not hours:
|
if not seconds or not minutes or not hours:
|
||||||
raise ValueError("Cannot find a next time: Time expression never matches!")
|
raise ValueError("Cannot find a next time: Time expression never matches!")
|
||||||
|
|
||||||
def _lower_bound(arr: List[int], cmp: int) -> Optional[int]:
|
def _lower_bound(arr: list[int], cmp: int) -> int | None:
|
||||||
"""Return the first value in arr greater or equal to cmp.
|
"""Return the first value in arr greater or equal to cmp.
|
||||||
|
|
||||||
Return None if no such value exists.
|
Return None if no such value exists.
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
"""JSON utility functions."""
|
"""JSON utility functions."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
from typing import Any, Callable
|
||||||
|
|
||||||
from homeassistant.core import Event, State
|
from homeassistant.core import Event, State
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
@ -20,9 +22,7 @@ class WriteError(HomeAssistantError):
|
||||||
"""Error writing the data."""
|
"""Error writing the data."""
|
||||||
|
|
||||||
|
|
||||||
def load_json(
|
def load_json(filename: str, default: list | dict | None = None) -> list | dict:
|
||||||
filename: str, default: Union[List, Dict, None] = None
|
|
||||||
) -> Union[List, Dict]:
|
|
||||||
"""Load JSON data from a file and return as dict or list.
|
"""Load JSON data from a file and return as dict or list.
|
||||||
|
|
||||||
Defaults to returning empty dict if file is not found.
|
Defaults to returning empty dict if file is not found.
|
||||||
|
@ -44,10 +44,10 @@ def load_json(
|
||||||
|
|
||||||
def save_json(
|
def save_json(
|
||||||
filename: str,
|
filename: str,
|
||||||
data: Union[List, Dict],
|
data: list | dict,
|
||||||
private: bool = False,
|
private: bool = False,
|
||||||
*,
|
*,
|
||||||
encoder: Optional[Type[json.JSONEncoder]] = None,
|
encoder: type[json.JSONEncoder] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save JSON data to a file.
|
"""Save JSON data to a file.
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ def save_json(
|
||||||
_LOGGER.error("JSON replacement cleanup failed: %s", err)
|
_LOGGER.error("JSON replacement cleanup failed: %s", err)
|
||||||
|
|
||||||
|
|
||||||
def format_unserializable_data(data: Dict[str, Any]) -> str:
|
def format_unserializable_data(data: dict[str, Any]) -> str:
|
||||||
"""Format output of find_paths in a friendly way.
|
"""Format output of find_paths in a friendly way.
|
||||||
|
|
||||||
Format is comma separated: <path>=<value>(<type>)
|
Format is comma separated: <path>=<value>(<type>)
|
||||||
|
@ -95,7 +95,7 @@ def format_unserializable_data(data: Dict[str, Any]) -> str:
|
||||||
|
|
||||||
def find_paths_unserializable_data(
|
def find_paths_unserializable_data(
|
||||||
bad_data: Any, *, dump: Callable[[Any], str] = json.dumps
|
bad_data: Any, *, dump: Callable[[Any], str] = json.dumps
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Find the paths to unserializable data.
|
"""Find the paths to unserializable data.
|
||||||
|
|
||||||
This method is slow! Only use for error handling.
|
This method is slow! Only use for error handling.
|
||||||
|
|
|
@ -3,10 +3,12 @@ Module with location helpers.
|
||||||
|
|
||||||
detect_location_info and elevation are mocked by default during tests.
|
detect_location_info and elevation are mocked by default during tests.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
import math
|
import math
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
@ -47,7 +49,7 @@ LocationInfo = collections.namedtuple(
|
||||||
|
|
||||||
async def async_detect_location_info(
|
async def async_detect_location_info(
|
||||||
session: aiohttp.ClientSession,
|
session: aiohttp.ClientSession,
|
||||||
) -> Optional[LocationInfo]:
|
) -> LocationInfo | None:
|
||||||
"""Detect location information."""
|
"""Detect location information."""
|
||||||
data = await _get_ipapi(session)
|
data = await _get_ipapi(session)
|
||||||
|
|
||||||
|
@ -63,8 +65,8 @@ async def async_detect_location_info(
|
||||||
|
|
||||||
|
|
||||||
def distance(
|
def distance(
|
||||||
lat1: Optional[float], lon1: Optional[float], lat2: float, lon2: float
|
lat1: float | None, lon1: float | None, lat2: float, lon2: float
|
||||||
) -> Optional[float]:
|
) -> float | None:
|
||||||
"""Calculate the distance in meters between two points.
|
"""Calculate the distance in meters between two points.
|
||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
|
@ -81,8 +83,8 @@ def distance(
|
||||||
# Source: https://github.com/maurycyp/vincenty
|
# Source: https://github.com/maurycyp/vincenty
|
||||||
# License: https://github.com/maurycyp/vincenty/blob/master/LICENSE
|
# License: https://github.com/maurycyp/vincenty/blob/master/LICENSE
|
||||||
def vincenty(
|
def vincenty(
|
||||||
point1: Tuple[float, float], point2: Tuple[float, float], miles: bool = False
|
point1: tuple[float, float], point2: tuple[float, float], miles: bool = False
|
||||||
) -> Optional[float]:
|
) -> float | None:
|
||||||
"""
|
"""
|
||||||
Vincenty formula (inverse method) to calculate the distance.
|
Vincenty formula (inverse method) to calculate the distance.
|
||||||
|
|
||||||
|
@ -162,7 +164,7 @@ def vincenty(
|
||||||
return round(s, 6)
|
return round(s, 6)
|
||||||
|
|
||||||
|
|
||||||
async def _get_ipapi(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]:
|
async def _get_ipapi(session: aiohttp.ClientSession) -> dict[str, Any] | None:
|
||||||
"""Query ipapi.co for location data."""
|
"""Query ipapi.co for location data."""
|
||||||
try:
|
try:
|
||||||
resp = await session.get(IPAPI, timeout=5)
|
resp = await session.get(IPAPI, timeout=5)
|
||||||
|
@ -192,7 +194,7 @@ async def _get_ipapi(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def _get_ip_api(session: aiohttp.ClientSession) -> Optional[Dict[str, Any]]:
|
async def _get_ip_api(session: aiohttp.ClientSession) -> dict[str, Any] | None:
|
||||||
"""Query ip-api.com for location data."""
|
"""Query ip-api.com for location data."""
|
||||||
try:
|
try:
|
||||||
resp = await session.get(IP_API, timeout=5)
|
resp = await session.get(IP_API, timeout=5)
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
"""Logging utilities."""
|
"""Logging utilities."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
import inspect
|
import inspect
|
||||||
|
@ -6,7 +8,7 @@ import logging
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
import queue
|
import queue
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any, Awaitable, Callable, Coroutine, Union, cast, overload
|
from typing import Any, Awaitable, Callable, Coroutine, cast, overload
|
||||||
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE
|
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
@ -115,7 +117,7 @@ def catch_log_exception(
|
||||||
|
|
||||||
def catch_log_exception(
|
def catch_log_exception(
|
||||||
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
|
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
|
||||||
) -> Union[Callable[..., None], Callable[..., Awaitable[None]]]:
|
) -> Callable[..., None] | Callable[..., Awaitable[None]]:
|
||||||
"""Decorate a callback to catch and log exceptions."""
|
"""Decorate a callback to catch and log exceptions."""
|
||||||
|
|
||||||
# Check for partials to properly determine if coroutine function
|
# Check for partials to properly determine if coroutine function
|
||||||
|
@ -123,7 +125,7 @@ def catch_log_exception(
|
||||||
while isinstance(check_func, partial):
|
while isinstance(check_func, partial):
|
||||||
check_func = check_func.func
|
check_func = check_func.func
|
||||||
|
|
||||||
wrapper_func: Union[Callable[..., None], Callable[..., Awaitable[None]]]
|
wrapper_func: Callable[..., None] | Callable[..., Awaitable[None]]
|
||||||
if asyncio.iscoroutinefunction(check_func):
|
if asyncio.iscoroutinefunction(check_func):
|
||||||
async_func = cast(Callable[..., Awaitable[None]], func)
|
async_func = cast(Callable[..., Awaitable[None]], func)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Network utilities."""
|
"""Network utilities."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from ipaddress import IPv4Address, IPv6Address, ip_address, ip_network
|
from ipaddress import IPv4Address, IPv6Address, ip_address, ip_network
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import yarl
|
import yarl
|
||||||
|
|
||||||
|
@ -23,22 +24,22 @@ PRIVATE_NETWORKS = (
|
||||||
LINK_LOCAL_NETWORK = ip_network("169.254.0.0/16")
|
LINK_LOCAL_NETWORK = ip_network("169.254.0.0/16")
|
||||||
|
|
||||||
|
|
||||||
def is_loopback(address: Union[IPv4Address, IPv6Address]) -> bool:
|
def is_loopback(address: IPv4Address | IPv6Address) -> bool:
|
||||||
"""Check if an address is a loopback address."""
|
"""Check if an address is a loopback address."""
|
||||||
return any(address in network for network in LOOPBACK_NETWORKS)
|
return any(address in network for network in LOOPBACK_NETWORKS)
|
||||||
|
|
||||||
|
|
||||||
def is_private(address: Union[IPv4Address, IPv6Address]) -> bool:
|
def is_private(address: IPv4Address | IPv6Address) -> bool:
|
||||||
"""Check if an address is a private address."""
|
"""Check if an address is a private address."""
|
||||||
return any(address in network for network in PRIVATE_NETWORKS)
|
return any(address in network for network in PRIVATE_NETWORKS)
|
||||||
|
|
||||||
|
|
||||||
def is_link_local(address: Union[IPv4Address, IPv6Address]) -> bool:
|
def is_link_local(address: IPv4Address | IPv6Address) -> bool:
|
||||||
"""Check if an address is link local."""
|
"""Check if an address is link local."""
|
||||||
return address in LINK_LOCAL_NETWORK
|
return address in LINK_LOCAL_NETWORK
|
||||||
|
|
||||||
|
|
||||||
def is_local(address: Union[IPv4Address, IPv6Address]) -> bool:
|
def is_local(address: IPv4Address | IPv6Address) -> bool:
|
||||||
"""Check if an address is loopback or private."""
|
"""Check if an address is loopback or private."""
|
||||||
return is_loopback(address) or is_private(address)
|
return is_loopback(address) or is_private(address)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
"""Helpers to install PyPi packages."""
|
"""Helpers to install PyPi packages."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
import logging
|
import logging
|
||||||
|
@ -6,7 +8,6 @@ import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from subprocess import PIPE, Popen
|
from subprocess import PIPE, Popen
|
||||||
import sys
|
import sys
|
||||||
from typing import Optional
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
@ -59,10 +60,10 @@ def is_installed(package: str) -> bool:
|
||||||
def install_package(
|
def install_package(
|
||||||
package: str,
|
package: str,
|
||||||
upgrade: bool = True,
|
upgrade: bool = True,
|
||||||
target: Optional[str] = None,
|
target: str | None = None,
|
||||||
constraints: Optional[str] = None,
|
constraints: str | None = None,
|
||||||
find_links: Optional[str] = None,
|
find_links: str | None = None,
|
||||||
no_cache_dir: Optional[bool] = False,
|
no_cache_dir: bool | None = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Install a package on PyPi. Accepts pip compatible package strings.
|
"""Install a package on PyPi. Accepts pip compatible package strings.
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
"""Percentage util functions."""
|
"""Percentage util functions."""
|
||||||
|
from __future__ import annotations
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
def ordered_list_item_to_percentage(ordered_list: List[str], item: str) -> int:
|
def ordered_list_item_to_percentage(ordered_list: list[str], item: str) -> int:
|
||||||
"""Determine the percentage of an item in an ordered list.
|
"""Determine the percentage of an item in an ordered list.
|
||||||
|
|
||||||
When using this utility for fan speeds, do not include "off"
|
When using this utility for fan speeds, do not include "off"
|
||||||
|
@ -26,7 +25,7 @@ def ordered_list_item_to_percentage(ordered_list: List[str], item: str) -> int:
|
||||||
return (list_position * 100) // list_len
|
return (list_position * 100) // list_len
|
||||||
|
|
||||||
|
|
||||||
def percentage_to_ordered_list_item(ordered_list: List[str], percentage: int) -> str:
|
def percentage_to_ordered_list_item(ordered_list: list[str], percentage: int) -> str:
|
||||||
"""Find the item that most closely matches the percentage in an ordered list.
|
"""Find the item that most closely matches the percentage in an ordered list.
|
||||||
|
|
||||||
When using this utility for fan speeds, do not include "off"
|
When using this utility for fan speeds, do not include "off"
|
||||||
|
@ -54,7 +53,7 @@ def percentage_to_ordered_list_item(ordered_list: List[str], percentage: int) ->
|
||||||
|
|
||||||
|
|
||||||
def ranged_value_to_percentage(
|
def ranged_value_to_percentage(
|
||||||
low_high_range: Tuple[float, float], value: float
|
low_high_range: tuple[float, float], value: float
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Given a range of low and high values convert a single value to a percentage.
|
"""Given a range of low and high values convert a single value to a percentage.
|
||||||
|
|
||||||
|
@ -71,7 +70,7 @@ def ranged_value_to_percentage(
|
||||||
|
|
||||||
|
|
||||||
def percentage_to_ranged_value(
|
def percentage_to_ranged_value(
|
||||||
low_high_range: Tuple[float, float], percentage: int
|
low_high_range: tuple[float, float], percentage: int
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Given a range of low and high values convert a percentage to a single value.
|
"""Given a range of low and high values convert a percentage to a single value.
|
||||||
|
|
||||||
|
@ -87,11 +86,11 @@ def percentage_to_ranged_value(
|
||||||
return states_in_range(low_high_range) * percentage / 100
|
return states_in_range(low_high_range) * percentage / 100
|
||||||
|
|
||||||
|
|
||||||
def states_in_range(low_high_range: Tuple[float, float]) -> float:
|
def states_in_range(low_high_range: tuple[float, float]) -> float:
|
||||||
"""Given a range of low and high values return how many states exist."""
|
"""Given a range of low and high values return how many states exist."""
|
||||||
return low_high_range[1] - low_high_range[0] + 1
|
return low_high_range[1] - low_high_range[0] + 1
|
||||||
|
|
||||||
|
|
||||||
def int_states_in_range(low_high_range: Tuple[float, float]) -> int:
|
def int_states_in_range(low_high_range: tuple[float, float]) -> int:
|
||||||
"""Given a range of low and high values return how many integer states exist."""
|
"""Given a range of low and high values return how many integer states exist."""
|
||||||
return int(states_in_range(low_high_range))
|
return int(states_in_range(low_high_range))
|
||||||
|
|
|
@ -2,18 +2,18 @@
|
||||||
|
|
||||||
Can only be used by integrations that have pillow in their requirements.
|
Can only be used by integrations that have pillow in their requirements.
|
||||||
"""
|
"""
|
||||||
from typing import Tuple
|
from __future__ import annotations
|
||||||
|
|
||||||
from PIL import ImageDraw
|
from PIL import ImageDraw
|
||||||
|
|
||||||
|
|
||||||
def draw_box(
|
def draw_box(
|
||||||
draw: ImageDraw,
|
draw: ImageDraw,
|
||||||
box: Tuple[float, float, float, float],
|
box: tuple[float, float, float, float],
|
||||||
img_width: int,
|
img_width: int,
|
||||||
img_height: int,
|
img_height: int,
|
||||||
text: str = "",
|
text: str = "",
|
||||||
color: Tuple[int, int, int] = (255, 255, 0),
|
color: tuple[int, int, int] = (255, 255, 0),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Draw a bounding box on and image.
|
Draw a bounding box on and image.
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
"""ruamel.yaml utility functions."""
|
"""ruamel.yaml utility functions."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from os import O_CREAT, O_TRUNC, O_WRONLY, stat_result
|
from os import O_CREAT, O_TRUNC, O_WRONLY, stat_result
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import ruamel.yaml
|
import ruamel.yaml
|
||||||
from ruamel.yaml import YAML # type: ignore
|
from ruamel.yaml import YAML # type: ignore
|
||||||
|
@ -22,7 +24,7 @@ JSON_TYPE = Union[List, Dict, str] # pylint: disable=invalid-name
|
||||||
class ExtSafeConstructor(SafeConstructor):
|
class ExtSafeConstructor(SafeConstructor):
|
||||||
"""Extended SafeConstructor."""
|
"""Extended SafeConstructor."""
|
||||||
|
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedYamlError(HomeAssistantError):
|
class UnsupportedYamlError(HomeAssistantError):
|
||||||
|
@ -77,7 +79,7 @@ def yaml_to_object(data: str) -> JSON_TYPE:
|
||||||
"""Create object from yaml string."""
|
"""Create object from yaml string."""
|
||||||
yaml = YAML(typ="rt")
|
yaml = YAML(typ="rt")
|
||||||
try:
|
try:
|
||||||
result: Union[List, Dict, str] = yaml.load(data)
|
result: list | dict | str = yaml.load(data)
|
||||||
return result
|
return result
|
||||||
except YAMLError as exc:
|
except YAMLError as exc:
|
||||||
_LOGGER.error("YAML error: %s", exc)
|
_LOGGER.error("YAML error: %s", exc)
|
||||||
|
|
|
@ -8,7 +8,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import enum
|
import enum
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, Dict, List, Optional, Type, Union
|
from typing import Any
|
||||||
|
|
||||||
from .async_ import run_callback_threadsafe
|
from .async_ import run_callback_threadsafe
|
||||||
|
|
||||||
|
@ -38,10 +38,10 @@ class _GlobalFreezeContext:
|
||||||
|
|
||||||
async def __aexit__(
|
async def __aexit__(
|
||||||
self,
|
self,
|
||||||
exc_type: Type[BaseException],
|
exc_type: type[BaseException],
|
||||||
exc_val: BaseException,
|
exc_val: BaseException,
|
||||||
exc_tb: TracebackType,
|
exc_tb: TracebackType,
|
||||||
) -> Optional[bool]:
|
) -> bool | None:
|
||||||
self._exit()
|
self._exit()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -51,10 +51,10 @@ class _GlobalFreezeContext:
|
||||||
|
|
||||||
def __exit__( # pylint: disable=useless-return
|
def __exit__( # pylint: disable=useless-return
|
||||||
self,
|
self,
|
||||||
exc_type: Type[BaseException],
|
exc_type: type[BaseException],
|
||||||
exc_val: BaseException,
|
exc_val: BaseException,
|
||||||
exc_tb: TracebackType,
|
exc_tb: TracebackType,
|
||||||
) -> Optional[bool]:
|
) -> bool | None:
|
||||||
self._loop.call_soon_threadsafe(self._exit)
|
self._loop.call_soon_threadsafe(self._exit)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -106,10 +106,10 @@ class _ZoneFreezeContext:
|
||||||
|
|
||||||
async def __aexit__(
|
async def __aexit__(
|
||||||
self,
|
self,
|
||||||
exc_type: Type[BaseException],
|
exc_type: type[BaseException],
|
||||||
exc_val: BaseException,
|
exc_val: BaseException,
|
||||||
exc_tb: TracebackType,
|
exc_tb: TracebackType,
|
||||||
) -> Optional[bool]:
|
) -> bool | None:
|
||||||
self._exit()
|
self._exit()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -119,10 +119,10 @@ class _ZoneFreezeContext:
|
||||||
|
|
||||||
def __exit__( # pylint: disable=useless-return
|
def __exit__( # pylint: disable=useless-return
|
||||||
self,
|
self,
|
||||||
exc_type: Type[BaseException],
|
exc_type: type[BaseException],
|
||||||
exc_val: BaseException,
|
exc_val: BaseException,
|
||||||
exc_tb: TracebackType,
|
exc_tb: TracebackType,
|
||||||
) -> Optional[bool]:
|
) -> bool | None:
|
||||||
self._loop.call_soon_threadsafe(self._exit)
|
self._loop.call_soon_threadsafe(self._exit)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -155,8 +155,8 @@ class _GlobalTaskContext:
|
||||||
self._manager: TimeoutManager = manager
|
self._manager: TimeoutManager = manager
|
||||||
self._task: asyncio.Task[Any] = task
|
self._task: asyncio.Task[Any] = task
|
||||||
self._time_left: float = timeout
|
self._time_left: float = timeout
|
||||||
self._expiration_time: Optional[float] = None
|
self._expiration_time: float | None = None
|
||||||
self._timeout_handler: Optional[asyncio.Handle] = None
|
self._timeout_handler: asyncio.Handle | None = None
|
||||||
self._wait_zone: asyncio.Event = asyncio.Event()
|
self._wait_zone: asyncio.Event = asyncio.Event()
|
||||||
self._state: _State = _State.INIT
|
self._state: _State = _State.INIT
|
||||||
self._cool_down: float = cool_down
|
self._cool_down: float = cool_down
|
||||||
|
@ -169,10 +169,10 @@ class _GlobalTaskContext:
|
||||||
|
|
||||||
async def __aexit__(
|
async def __aexit__(
|
||||||
self,
|
self,
|
||||||
exc_type: Type[BaseException],
|
exc_type: type[BaseException],
|
||||||
exc_val: BaseException,
|
exc_val: BaseException,
|
||||||
exc_tb: TracebackType,
|
exc_tb: TracebackType,
|
||||||
) -> Optional[bool]:
|
) -> bool | None:
|
||||||
self._stop_timer()
|
self._stop_timer()
|
||||||
self._manager.global_tasks.remove(self)
|
self._manager.global_tasks.remove(self)
|
||||||
|
|
||||||
|
@ -263,8 +263,8 @@ class _ZoneTaskContext:
|
||||||
self._task: asyncio.Task[Any] = task
|
self._task: asyncio.Task[Any] = task
|
||||||
self._state: _State = _State.INIT
|
self._state: _State = _State.INIT
|
||||||
self._time_left: float = timeout
|
self._time_left: float = timeout
|
||||||
self._expiration_time: Optional[float] = None
|
self._expiration_time: float | None = None
|
||||||
self._timeout_handler: Optional[asyncio.Handle] = None
|
self._timeout_handler: asyncio.Handle | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self) -> _State:
|
def state(self) -> _State:
|
||||||
|
@ -283,10 +283,10 @@ class _ZoneTaskContext:
|
||||||
|
|
||||||
async def __aexit__(
|
async def __aexit__(
|
||||||
self,
|
self,
|
||||||
exc_type: Type[BaseException],
|
exc_type: type[BaseException],
|
||||||
exc_val: BaseException,
|
exc_val: BaseException,
|
||||||
exc_tb: TracebackType,
|
exc_tb: TracebackType,
|
||||||
) -> Optional[bool]:
|
) -> bool | None:
|
||||||
self._zone.exit_task(self)
|
self._zone.exit_task(self)
|
||||||
self._stop_timer()
|
self._stop_timer()
|
||||||
|
|
||||||
|
@ -344,8 +344,8 @@ class _ZoneTimeoutManager:
|
||||||
"""Initialize internal timeout context manager."""
|
"""Initialize internal timeout context manager."""
|
||||||
self._manager: TimeoutManager = manager
|
self._manager: TimeoutManager = manager
|
||||||
self._zone: str = zone
|
self._zone: str = zone
|
||||||
self._tasks: List[_ZoneTaskContext] = []
|
self._tasks: list[_ZoneTaskContext] = []
|
||||||
self._freezes: List[_ZoneFreezeContext] = []
|
self._freezes: list[_ZoneFreezeContext] = []
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Representation of a zone."""
|
"""Representation of a zone."""
|
||||||
|
@ -418,9 +418,9 @@ class TimeoutManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize TimeoutManager."""
|
"""Initialize TimeoutManager."""
|
||||||
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
|
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
|
||||||
self._zones: Dict[str, _ZoneTimeoutManager] = {}
|
self._zones: dict[str, _ZoneTimeoutManager] = {}
|
||||||
self._globals: List[_GlobalTaskContext] = []
|
self._globals: list[_GlobalTaskContext] = []
|
||||||
self._freezes: List[_GlobalFreezeContext] = []
|
self._freezes: list[_GlobalFreezeContext] = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def zones_done(self) -> bool:
|
def zones_done(self) -> bool:
|
||||||
|
@ -433,17 +433,17 @@ class TimeoutManager:
|
||||||
return not self._freezes
|
return not self._freezes
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def zones(self) -> Dict[str, _ZoneTimeoutManager]:
|
def zones(self) -> dict[str, _ZoneTimeoutManager]:
|
||||||
"""Return all Zones."""
|
"""Return all Zones."""
|
||||||
return self._zones
|
return self._zones
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def global_tasks(self) -> List[_GlobalTaskContext]:
|
def global_tasks(self) -> list[_GlobalTaskContext]:
|
||||||
"""Return all global Tasks."""
|
"""Return all global Tasks."""
|
||||||
return self._globals
|
return self._globals
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def global_freezes(self) -> List[_GlobalFreezeContext]:
|
def global_freezes(self) -> list[_GlobalFreezeContext]:
|
||||||
"""Return all global Freezes."""
|
"""Return all global Freezes."""
|
||||||
return self._freezes
|
return self._freezes
|
||||||
|
|
||||||
|
@ -459,12 +459,12 @@ class TimeoutManager:
|
||||||
|
|
||||||
def async_timeout(
|
def async_timeout(
|
||||||
self, timeout: float, zone_name: str = ZONE_GLOBAL, cool_down: float = 0
|
self, timeout: float, zone_name: str = ZONE_GLOBAL, cool_down: float = 0
|
||||||
) -> Union[_ZoneTaskContext, _GlobalTaskContext]:
|
) -> _ZoneTaskContext | _GlobalTaskContext:
|
||||||
"""Timeout based on a zone.
|
"""Timeout based on a zone.
|
||||||
|
|
||||||
For using as Async Context Manager.
|
For using as Async Context Manager.
|
||||||
"""
|
"""
|
||||||
current_task: Optional[asyncio.Task[Any]] = asyncio.current_task()
|
current_task: asyncio.Task[Any] | None = asyncio.current_task()
|
||||||
assert current_task
|
assert current_task
|
||||||
|
|
||||||
# Global Zone
|
# Global Zone
|
||||||
|
@ -483,7 +483,7 @@ class TimeoutManager:
|
||||||
|
|
||||||
def async_freeze(
|
def async_freeze(
|
||||||
self, zone_name: str = ZONE_GLOBAL
|
self, zone_name: str = ZONE_GLOBAL
|
||||||
) -> Union[_ZoneFreezeContext, _GlobalFreezeContext]:
|
) -> _ZoneFreezeContext | _GlobalFreezeContext:
|
||||||
"""Freeze all timer until job is done.
|
"""Freeze all timer until job is done.
|
||||||
|
|
||||||
For using as Async Context Manager.
|
For using as Async Context Manager.
|
||||||
|
@ -502,7 +502,7 @@ class TimeoutManager:
|
||||||
|
|
||||||
def freeze(
|
def freeze(
|
||||||
self, zone_name: str = ZONE_GLOBAL
|
self, zone_name: str = ZONE_GLOBAL
|
||||||
) -> Union[_ZoneFreezeContext, _GlobalFreezeContext]:
|
) -> _ZoneFreezeContext | _GlobalFreezeContext:
|
||||||
"""Freeze all timer until job is done.
|
"""Freeze all timer until job is done.
|
||||||
|
|
||||||
For using as Context Manager.
|
For using as Context Manager.
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Unit system helper class and methods."""
|
"""Unit system helper class and methods."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
CONF_UNIT_SYSTEM_IMPERIAL,
|
CONF_UNIT_SYSTEM_IMPERIAL,
|
||||||
|
@ -109,7 +110,7 @@ class UnitSystem:
|
||||||
|
|
||||||
return temperature_util.convert(temperature, from_unit, self.temperature_unit)
|
return temperature_util.convert(temperature, from_unit, self.temperature_unit)
|
||||||
|
|
||||||
def length(self, length: Optional[float], from_unit: str) -> float:
|
def length(self, length: float | None, from_unit: str) -> float:
|
||||||
"""Convert the given length to this unit system."""
|
"""Convert the given length to this unit system."""
|
||||||
if not isinstance(length, Number):
|
if not isinstance(length, Number):
|
||||||
raise TypeError(f"{length!s} is not a numeric value.")
|
raise TypeError(f"{length!s} is not a numeric value.")
|
||||||
|
@ -119,7 +120,7 @@ class UnitSystem:
|
||||||
length, from_unit, self.length_unit
|
length, from_unit, self.length_unit
|
||||||
)
|
)
|
||||||
|
|
||||||
def pressure(self, pressure: Optional[float], from_unit: str) -> float:
|
def pressure(self, pressure: float | None, from_unit: str) -> float:
|
||||||
"""Convert the given pressure to this unit system."""
|
"""Convert the given pressure to this unit system."""
|
||||||
if not isinstance(pressure, Number):
|
if not isinstance(pressure, Number):
|
||||||
raise TypeError(f"{pressure!s} is not a numeric value.")
|
raise TypeError(f"{pressure!s} is not a numeric value.")
|
||||||
|
@ -129,7 +130,7 @@ class UnitSystem:
|
||||||
pressure, from_unit, self.pressure_unit
|
pressure, from_unit, self.pressure_unit
|
||||||
)
|
)
|
||||||
|
|
||||||
def volume(self, volume: Optional[float], from_unit: str) -> float:
|
def volume(self, volume: float | None, from_unit: str) -> float:
|
||||||
"""Convert the given volume to this unit system."""
|
"""Convert the given volume to this unit system."""
|
||||||
if not isinstance(volume, Number):
|
if not isinstance(volume, Number):
|
||||||
raise TypeError(f"{volume!s} is not a numeric value.")
|
raise TypeError(f"{volume!s} is not a numeric value.")
|
||||||
|
@ -137,7 +138,7 @@ class UnitSystem:
|
||||||
# type ignore: https://github.com/python/mypy/issues/7207
|
# type ignore: https://github.com/python/mypy/issues/7207
|
||||||
return volume_util.convert(volume, from_unit, self.volume_unit) # type: ignore
|
return volume_util.convert(volume, from_unit, self.volume_unit) # type: ignore
|
||||||
|
|
||||||
def as_dict(self) -> Dict[str, str]:
|
def as_dict(self) -> dict[str, str]:
|
||||||
"""Convert the unit system to a dictionary."""
|
"""Convert the unit system to a dictionary."""
|
||||||
return {
|
return {
|
||||||
LENGTH: self.length_unit,
|
LENGTH: self.length_unit,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""Deal with YAML input."""
|
"""Deal with YAML input."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, Set
|
from typing import Any
|
||||||
|
|
||||||
from .objects import Input
|
from .objects import Input
|
||||||
|
|
||||||
|
@ -14,14 +15,14 @@ class UndefinedSubstitution(Exception):
|
||||||
self.input = input
|
self.input = input
|
||||||
|
|
||||||
|
|
||||||
def extract_inputs(obj: Any) -> Set[str]:
|
def extract_inputs(obj: Any) -> set[str]:
|
||||||
"""Extract input from a structure."""
|
"""Extract input from a structure."""
|
||||||
found: Set[str] = set()
|
found: set[str] = set()
|
||||||
_extract_inputs(obj, found)
|
_extract_inputs(obj, found)
|
||||||
return found
|
return found
|
||||||
|
|
||||||
|
|
||||||
def _extract_inputs(obj: Any, found: Set[str]) -> None:
|
def _extract_inputs(obj: Any, found: set[str]) -> None:
|
||||||
"""Extract input from a structure."""
|
"""Extract input from a structure."""
|
||||||
if isinstance(obj, Input):
|
if isinstance(obj, Input):
|
||||||
found.add(obj.name)
|
found.add(obj.name)
|
||||||
|
@ -38,7 +39,7 @@ def _extract_inputs(obj: Any, found: Set[str]) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def substitute(obj: Any, substitutions: Dict[str, Any]) -> Any:
|
def substitute(obj: Any, substitutions: dict[str, Any]) -> Any:
|
||||||
"""Substitute values."""
|
"""Substitute values."""
|
||||||
if isinstance(obj, Input):
|
if isinstance(obj, Input):
|
||||||
if obj.name not in substitutions:
|
if obj.name not in substitutions:
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
"""Custom loader."""
|
"""Custom loader."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Iterator, List, Optional, TextIO, TypeVar, Union, overload
|
from typing import Any, Dict, Iterator, List, TextIO, TypeVar, Union, overload
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
@ -27,7 +29,7 @@ class Secrets:
|
||||||
def __init__(self, config_dir: Path):
|
def __init__(self, config_dir: Path):
|
||||||
"""Initialize secrets."""
|
"""Initialize secrets."""
|
||||||
self.config_dir = config_dir
|
self.config_dir = config_dir
|
||||||
self._cache: Dict[Path, Dict[str, str]] = {}
|
self._cache: dict[Path, dict[str, str]] = {}
|
||||||
|
|
||||||
def get(self, requester_path: str, secret: str) -> str:
|
def get(self, requester_path: str, secret: str) -> str:
|
||||||
"""Return the value of a secret."""
|
"""Return the value of a secret."""
|
||||||
|
@ -55,7 +57,7 @@ class Secrets:
|
||||||
|
|
||||||
raise HomeAssistantError(f"Secret {secret} not defined")
|
raise HomeAssistantError(f"Secret {secret} not defined")
|
||||||
|
|
||||||
def _load_secret_yaml(self, secret_dir: Path) -> Dict[str, str]:
|
def _load_secret_yaml(self, secret_dir: Path) -> dict[str, str]:
|
||||||
"""Load the secrets yaml from path."""
|
"""Load the secrets yaml from path."""
|
||||||
secret_path = secret_dir / SECRET_YAML
|
secret_path = secret_dir / SECRET_YAML
|
||||||
|
|
||||||
|
@ -90,7 +92,7 @@ class Secrets:
|
||||||
class SafeLineLoader(yaml.SafeLoader):
|
class SafeLineLoader(yaml.SafeLoader):
|
||||||
"""Loader class that keeps track of line numbers."""
|
"""Loader class that keeps track of line numbers."""
|
||||||
|
|
||||||
def __init__(self, stream: Any, secrets: Optional[Secrets] = None) -> None:
|
def __init__(self, stream: Any, secrets: Secrets | None = None) -> None:
|
||||||
"""Initialize a safe line loader."""
|
"""Initialize a safe line loader."""
|
||||||
super().__init__(stream)
|
super().__init__(stream)
|
||||||
self.secrets = secrets
|
self.secrets = secrets
|
||||||
|
@ -103,7 +105,7 @@ class SafeLineLoader(yaml.SafeLoader):
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
def load_yaml(fname: str, secrets: Optional[Secrets] = None) -> JSON_TYPE:
|
def load_yaml(fname: str, secrets: Secrets | None = None) -> JSON_TYPE:
|
||||||
"""Load a YAML file."""
|
"""Load a YAML file."""
|
||||||
try:
|
try:
|
||||||
with open(fname, encoding="utf-8") as conf_file:
|
with open(fname, encoding="utf-8") as conf_file:
|
||||||
|
@ -113,9 +115,7 @@ def load_yaml(fname: str, secrets: Optional[Secrets] = None) -> JSON_TYPE:
|
||||||
raise HomeAssistantError(exc) from exc
|
raise HomeAssistantError(exc) from exc
|
||||||
|
|
||||||
|
|
||||||
def parse_yaml(
|
def parse_yaml(content: str | TextIO, secrets: Secrets | None = None) -> JSON_TYPE:
|
||||||
content: Union[str, TextIO], secrets: Optional[Secrets] = None
|
|
||||||
) -> JSON_TYPE:
|
|
||||||
"""Load a YAML file."""
|
"""Load a YAML file."""
|
||||||
try:
|
try:
|
||||||
# If configuration file is empty YAML returns None
|
# If configuration file is empty YAML returns None
|
||||||
|
@ -131,14 +131,14 @@ def parse_yaml(
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def _add_reference(
|
def _add_reference(
|
||||||
obj: Union[list, NodeListClass], loader: SafeLineLoader, node: yaml.nodes.Node
|
obj: list | NodeListClass, loader: SafeLineLoader, node: yaml.nodes.Node
|
||||||
) -> NodeListClass:
|
) -> NodeListClass:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def _add_reference(
|
def _add_reference(
|
||||||
obj: Union[str, NodeStrClass], loader: SafeLineLoader, node: yaml.nodes.Node
|
obj: str | NodeStrClass, loader: SafeLineLoader, node: yaml.nodes.Node
|
||||||
) -> NodeStrClass:
|
) -> NodeStrClass:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -223,7 +223,7 @@ def _include_dir_merge_named_yaml(
|
||||||
|
|
||||||
def _include_dir_list_yaml(
|
def _include_dir_list_yaml(
|
||||||
loader: SafeLineLoader, node: yaml.nodes.Node
|
loader: SafeLineLoader, node: yaml.nodes.Node
|
||||||
) -> List[JSON_TYPE]:
|
) -> list[JSON_TYPE]:
|
||||||
"""Load multiple files from directory as a list."""
|
"""Load multiple files from directory as a list."""
|
||||||
loc = os.path.join(os.path.dirname(loader.name), node.value)
|
loc = os.path.join(os.path.dirname(loader.name), node.value)
|
||||||
return [
|
return [
|
||||||
|
@ -238,7 +238,7 @@ def _include_dir_merge_list_yaml(
|
||||||
) -> JSON_TYPE:
|
) -> JSON_TYPE:
|
||||||
"""Load multiple files from directory as a merged list."""
|
"""Load multiple files from directory as a merged list."""
|
||||||
loc: str = os.path.join(os.path.dirname(loader.name), node.value)
|
loc: str = os.path.join(os.path.dirname(loader.name), node.value)
|
||||||
merged_list: List[JSON_TYPE] = []
|
merged_list: list[JSON_TYPE] = []
|
||||||
for fname in _find_files(loc, "*.yaml"):
|
for fname in _find_files(loc, "*.yaml"):
|
||||||
if os.path.basename(fname) == SECRET_YAML:
|
if os.path.basename(fname) == SECRET_YAML:
|
||||||
continue
|
continue
|
||||||
|
@ -253,7 +253,7 @@ def _ordered_dict(loader: SafeLineLoader, node: yaml.nodes.MappingNode) -> Order
|
||||||
loader.flatten_mapping(node)
|
loader.flatten_mapping(node)
|
||||||
nodes = loader.construct_pairs(node)
|
nodes = loader.construct_pairs(node)
|
||||||
|
|
||||||
seen: Dict = {}
|
seen: dict = {}
|
||||||
for (key, _), (child_node, _) in zip(nodes, node.value):
|
for (key, _), (child_node, _) in zip(nodes, node.value):
|
||||||
line = child_node.start_mark.line
|
line = child_node.start_mark.line
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue