Use new syntax for TypeVar defaults (#135780)

pull/135830/head
Marc Mueller 2025-01-17 09:12:52 +01:00 committed by GitHub
parent 6aed2dcc0f
commit 46b17b539c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 35 additions and 62 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import logging
import types
from typing import Any, Generic, TypeVar
from typing import Any
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -34,12 +34,6 @@ DATA_REQS: HassKey[set[str]] = HassKey("mfa_auth_module_reqs_processed")
_LOGGER = logging.getLogger(__name__)
_MultiFactorAuthModuleT = TypeVar(
"_MultiFactorAuthModuleT",
bound="MultiFactorAuthModule",
default="MultiFactorAuthModule",
)
class MultiFactorAuthModule:
"""Multi-factor Auth Module of validation function."""
@ -101,7 +95,9 @@ class MultiFactorAuthModule:
raise NotImplementedError
class SetupFlow(data_entry_flow.FlowHandler, Generic[_MultiFactorAuthModuleT]):
class SetupFlow[_MultiFactorAuthModuleT: MultiFactorAuthModule = MultiFactorAuthModule](
data_entry_flow.FlowHandler
):
"""Handler for the setup flow."""
def __init__(

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Mapping
import logging
import types
from typing import Any, Generic, TypeVar
from typing import Any
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -46,8 +46,6 @@ AUTH_PROVIDER_SCHEMA = vol.Schema(
extra=vol.ALLOW_EXTRA,
)
_AuthProviderT = TypeVar("_AuthProviderT", bound="AuthProvider", default="AuthProvider")
class AuthProvider:
"""Provider of user authentication."""
@ -194,9 +192,8 @@ async def load_auth_provider_module(
return module
class LoginFlow(
class LoginFlow[_AuthProviderT: AuthProvider = AuthProvider](
FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]],
Generic[_AuthProviderT],
):
"""Handler for the login flow."""

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers.update_coordinator import (
@ -18,12 +18,6 @@ if TYPE_CHECKING:
from . import BluetoothChange, BluetoothScanningMode, BluetoothServiceInfoBleak
_PassiveBluetoothDataUpdateCoordinatorT = TypeVar(
"_PassiveBluetoothDataUpdateCoordinatorT",
bound="PassiveBluetoothDataUpdateCoordinator",
default="PassiveBluetoothDataUpdateCoordinator",
)
class PassiveBluetoothDataUpdateCoordinator(
BasePassiveBluetoothCoordinator, BaseDataUpdateCoordinatorProtocol
@ -96,7 +90,9 @@ class PassiveBluetoothDataUpdateCoordinator(
self.async_update_listeners()
class PassiveBluetoothCoordinatorEntity( # pylint: disable=hass-enforce-class-module
class PassiveBluetoothCoordinatorEntity[
_PassiveBluetoothDataUpdateCoordinatorT: PassiveBluetoothDataUpdateCoordinator = PassiveBluetoothDataUpdateCoordinator
]( # pylint: disable=hass-enforce-class-module
BaseCoordinatorEntity[_PassiveBluetoothDataUpdateCoordinatorT]
):
"""A class for entities using DataUpdateCoordinator."""

View File

@ -3,7 +3,6 @@
from contextlib import suppress
from functools import partial
import logging
from typing import Generic, TypeVar
import broadlink as blk
from broadlink.exceptions import (
@ -30,8 +29,6 @@ from homeassistant.helpers import device_registry as dr
from .const import DEFAULT_PORT, DOMAIN, DOMAINS_AND_TYPES
from .updater import BroadlinkUpdateManager, get_update_manager
_ApiT = TypeVar("_ApiT", bound=blk.Device, default=blk.Device)
_LOGGER = logging.getLogger(__name__)
@ -40,7 +37,7 @@ def get_domains(device_type: str) -> set[Platform]:
return {d for d, t in DOMAINS_AND_TYPES.items() if device_type in t}
class BroadlinkDevice(Generic[_ApiT]):
class BroadlinkDevice[_ApiT: blk.Device = blk.Device]:
"""Manages a Broadlink device."""
api: _ApiT

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Awaitable, Callable, Coroutine
from dataclasses import dataclass
from typing import Any, ParamSpec, TypeVar
from typing import Any
from reolink_aio.exceptions import (
ApiError,
@ -87,17 +87,13 @@ def get_device_uid_and_ch(
return (device_uid, ch, is_chime)
T = TypeVar("T")
P = ParamSpec("P")
# Decorators
def raise_translated_error(
func: Callable[P, Awaitable[T]],
) -> Callable[P, Coroutine[Any, Any, T]]:
def raise_translated_error[**P, R](
func: Callable[P, Awaitable[R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
"""Wrap a reolink-aio function to translate any potential errors."""
async def decorator_raise_translated_error(*args: P.args, **kwargs: P.kwargs) -> T:
async def decorator_raise_translated_error(*args: P.args, **kwargs: P.kwargs) -> R:
"""Try a reolink-aio function and translate any potential errors."""
try:
return await func(*args, **kwargs)

View File

@ -22,7 +22,7 @@ from functools import cache
import logging
from random import randint
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar, cast
from typing import TYPE_CHECKING, Any, Self, cast
from async_interrupt import interrupt
from propcache import cached_property
@ -136,8 +136,6 @@ DISCOVERY_COOLDOWN = 1
ISSUE_UNIQUE_ID_COLLISION = "config_entry_unique_id_collision"
UNIQUE_ID_COLLISION_TITLE_LIMIT = 5
_DataT = TypeVar("_DataT", default=Any)
class ConfigEntryState(Enum):
"""Config entry state."""
@ -312,7 +310,7 @@ def _validate_item(*, disabled_by: ConfigEntryDisabler | Any | None = None) -> N
)
class ConfigEntry(Generic[_DataT]):
class ConfigEntry[_DataT = Any]:
"""Hold a configuration entry."""
entry_id: str

View File

@ -11,7 +11,7 @@ from hashlib import md5
from itertools import groupby
import logging
from operator import attrgetter
from typing import Any, Generic, TypedDict, TypeVar
from typing import Any, TypedDict
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -36,8 +36,6 @@ CHANGE_ADDED = "added"
CHANGE_UPDATED = "updated"
CHANGE_REMOVED = "removed"
_EntityT = TypeVar("_EntityT", bound=Entity, default=Entity)
@dataclass(slots=True)
class CollectionChange:
@ -447,7 +445,7 @@ _GROUP_BY_KEY = attrgetter("change_type")
@dataclass(slots=True, frozen=True)
class _CollectionLifeCycle(Generic[_EntityT]):
class _CollectionLifeCycle[_EntityT: Entity = Entity]:
"""Life cycle for a collection of entities."""
domain: str
@ -522,7 +520,7 @@ class _CollectionLifeCycle(Generic[_EntityT]):
@callback
def sync_entity_lifecycle(
def sync_entity_lifecycle[_EntityT: Entity = Entity](
hass: HomeAssistant,
domain: str,
platform: str,

View File

@ -7,7 +7,7 @@ from collections.abc import Callable, Iterable
from datetime import timedelta
import logging
from types import ModuleType
from typing import Any, Generic, TypeVar
from typing import Any
from homeassistant import config as conf_util
from homeassistant.config_entries import ConfigEntry
@ -37,8 +37,6 @@ from .typing import ConfigType, DiscoveryInfoType, VolDictType, VolSchemaType
DEFAULT_SCAN_INTERVAL = timedelta(seconds=15)
DATA_INSTANCES = "entity_components"
_EntityT = TypeVar("_EntityT", bound=entity.Entity, default=entity.Entity)
@bind_hass
async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
@ -62,7 +60,7 @@ async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
await entity_obj.async_update_ha_state(True)
class EntityComponent(Generic[_EntityT]):
class EntityComponent[_EntityT: entity.Entity = entity.Entity]:
"""The EntityComponent manages platforms that manage entities.
An example of an entity component is 'light', which manages platforms such

View File

@ -36,11 +36,6 @@ REQUEST_REFRESH_DEFAULT_COOLDOWN = 10
REQUEST_REFRESH_DEFAULT_IMMEDIATE = True
_DataT = TypeVar("_DataT", default=dict[str, Any])
_DataUpdateCoordinatorT = TypeVar(
"_DataUpdateCoordinatorT",
bound="DataUpdateCoordinator[Any]",
default="DataUpdateCoordinator[dict[str, Any]]",
)
class UpdateFailed(HomeAssistantError):
@ -564,7 +559,11 @@ class BaseCoordinatorEntity[
"""
class CoordinatorEntity(BaseCoordinatorEntity[_DataUpdateCoordinatorT]):
class CoordinatorEntity[
_DataUpdateCoordinatorT: DataUpdateCoordinator[Any] = DataUpdateCoordinator[
dict[str, Any]
]
](BaseCoordinatorEntity[_DataUpdateCoordinatorT]):
"""A class for entities using DataUpdateCoordinator."""
def __init__(

View File

@ -6,12 +6,10 @@ Custom for type checking. See stub file.
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Generic, TypeVar
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=Mapping[str, Any])
from typing import Any
class EventType(str, Generic[_DataT]):
class EventType[_DataT: Mapping[str, Any] = Mapping[str, Any]](str):
"""Custom type for Event.event_type.
At runtime this is a generic subclass of str.

View File

@ -8,7 +8,9 @@ __all__ = [
"EventType",
]
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=Mapping[str, Any])
_DataT = TypeVar( # needs to be invariant
"_DataT", bound=Mapping[str, Any], default=Mapping[str, Any]
)
class EventType(Generic[_DataT]):
"""Custom type for Event.event_type. At runtime delegated to str.

View File

@ -25,7 +25,7 @@ import os
import pathlib
import time
from types import FrameType, ModuleType
from typing import Any, Literal, NoReturn, TypeVar
from typing import Any, Literal, NoReturn
from unittest.mock import AsyncMock, Mock, patch
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
@ -113,8 +113,6 @@ from .testing_config.custom_components.test_constant_deprecation import (
import_deprecated_constant,
)
_DataT = TypeVar("_DataT", bound=Mapping[str, Any], default=dict[str, Any])
_LOGGER = logging.getLogger(__name__)
INSTANCES = []
CLIENT_ID = "https://example.com/app"
@ -1544,7 +1542,7 @@ def mock_platform(
module_cache[platform_path] = module or Mock()
def async_capture_events(
def async_capture_events[_DataT: Mapping[str, Any] = dict[str, Any]](
hass: HomeAssistant, event_name: EventType[_DataT] | str
) -> list[Event[_DataT]]:
"""Create a helper that captures events."""