Make TypeVars private (2) (#68206)

pull/68310/head
Marc Mueller 2022-03-17 19:09:55 +01:00 committed by GitHub
parent be7ef6115c
commit eae0c75620
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 29 deletions

View File

@ -4,13 +4,15 @@ from __future__ import annotations
from enum import Enum
from typing import Any, TypeVar
T = TypeVar("T", bound="StrEnum")
_StrEnumT = TypeVar("_StrEnumT", bound="StrEnum")
class StrEnum(str, Enum):
"""Partial backport of Python 3.11's StrEnum for our basic use cases."""
def __new__(cls: type[T], value: str, *args: Any, **kwargs: Any) -> T:
def __new__(
cls: type[_StrEnumT], value: str, *args: Any, **kwargs: Any
) -> _StrEnumT:
"""Create a new StrEnum instance."""
if not isinstance(value, str):
raise TypeError(f"{value!r} is not a string")

View File

@ -102,7 +102,7 @@ sun_event = vol.All(vol.Lower, vol.Any(SUN_EVENT_SUNSET, SUN_EVENT_SUNRISE))
port = vol.All(vol.Coerce(int), vol.Range(min=1, max=65535))
# typing typevar
T = TypeVar("T")
_T = TypeVar("_T")
def path(value: Any) -> str:
@ -253,20 +253,20 @@ def ensure_list(value: None) -> list[Any]:
@overload
def ensure_list(value: list[T]) -> list[T]:
def ensure_list(value: list[_T]) -> list[_T]:
...
@overload
def ensure_list(value: list[T] | T) -> list[T]:
def ensure_list(value: list[_T] | _T) -> list[_T]:
...
def ensure_list(value: T | None) -> list[T] | list[Any]:
def ensure_list(value: _T | None) -> list[_T] | list[Any]:
"""Wrap value in list if it is not one."""
if value is None:
return []
return cast("list[T]", value) if isinstance(value, list) else [value]
return cast("list[_T]", value) if isinstance(value, list) else [value]
def entity_id(value: Any) -> str:
@ -467,7 +467,7 @@ def time_period_seconds(value: float | str) -> timedelta:
time_period = vol.Any(time_period_str, time_period_seconds, timedelta, time_period_dict)
def match_all(value: T) -> T:
def match_all(value: _T) -> _T:
"""Validate that matches all values."""
return value
@ -483,7 +483,7 @@ positive_time_period_dict = vol.All(time_period_dict, positive_timedelta)
positive_time_period = vol.All(time_period, positive_timedelta)
def remove_falsy(value: list[T]) -> list[T]:
def remove_falsy(value: list[_T]) -> list[_T]:
"""Remove falsy values from a list."""
return [v for v in value if v]
@ -510,7 +510,7 @@ def slug(value: Any) -> str:
def schema_with_slug_keys(
value_schema: T | Callable, *, slug_validator: Callable[[Any], str] = slug
value_schema: _T | Callable, *, slug_validator: Callable[[Any], str] = slug
) -> Callable:
"""Ensure dicts have slugs as keys.

View File

@ -15,7 +15,7 @@ _LOGGER = logging.getLogger(__name__)
# Keep track of integrations already reported to prevent flooding
_REPORTED_INTEGRATIONS: set[str] = set()
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) # pylint: disable=invalid-name
_CallableT = TypeVar("_CallableT", bound=Callable)
def get_integration_frame(
@ -113,7 +113,7 @@ def report_integration(
)
def warn_use(func: CALLABLE_T, what: str) -> CALLABLE_T:
def warn_use(func: _CallableT, what: str) -> _CallableT:
"""Mock a function to warn when it was about to be used."""
if asyncio.iscoroutinefunction(func):
@ -127,4 +127,4 @@ def warn_use(func: CALLABLE_T, what: str) -> CALLABLE_T:
def report_use(*args: Any, **kwargs: Any) -> None:
report(what)
return cast(CALLABLE_T, report_use)
return cast(_CallableT, report_use)

View File

@ -9,9 +9,9 @@ from typing import TypeVar, cast
from homeassistant.core import HomeAssistant
from homeassistant.loader import bind_hass
T = TypeVar("T")
_T = TypeVar("_T")
FUNC = Callable[[HomeAssistant], T]
FUNC = Callable[[HomeAssistant], _T]
def singleton(data_key: str) -> Callable[[FUNC], FUNC]:
@ -26,30 +26,30 @@ def singleton(data_key: str) -> Callable[[FUNC], FUNC]:
@bind_hass
@functools.wraps(func)
def wrapped(hass: HomeAssistant) -> T:
def wrapped(hass: HomeAssistant) -> _T:
if data_key not in hass.data:
hass.data[data_key] = func(hass)
return cast(T, hass.data[data_key])
return cast(_T, hass.data[data_key])
return wrapped
@bind_hass
@functools.wraps(func)
async def async_wrapped(hass: HomeAssistant) -> T:
async def async_wrapped(hass: HomeAssistant) -> _T:
if data_key not in hass.data:
evt = hass.data[data_key] = asyncio.Event()
result = await func(hass)
hass.data[data_key] = result
evt.set()
return cast(T, result)
return cast(_T, result)
obj_or_evt = hass.data[data_key]
if isinstance(obj_or_evt, asyncio.Event):
await obj_or_evt.wait()
return cast(T, hass.data[data_key])
return cast(_T, hass.data[data_key])
return cast(T, obj_or_evt)
return cast(_T, obj_or_evt)
return async_wrapped

View File

@ -23,14 +23,14 @@ from .debounce import Debouncer
REQUEST_REFRESH_DEFAULT_COOLDOWN = 10
REQUEST_REFRESH_DEFAULT_IMMEDIATE = True
T = TypeVar("T")
_T = TypeVar("_T")
class UpdateFailed(Exception):
"""Raised when an update has failed."""
class DataUpdateCoordinator(Generic[T]):
class DataUpdateCoordinator(Generic[_T]):
"""Class to manage fetching data from single endpoint."""
def __init__(
@ -40,7 +40,7 @@ class DataUpdateCoordinator(Generic[T]):
*,
name: str,
update_interval: timedelta | None = None,
update_method: Callable[[], Awaitable[T]] | None = None,
update_method: Callable[[], Awaitable[_T]] | None = None,
request_refresh_debouncer: Debouncer | None = None,
) -> None:
"""Initialize global data updater."""
@ -56,7 +56,7 @@ class DataUpdateCoordinator(Generic[T]):
# to make sure the first update was successful.
# Set type to just T to remove annoying checks that data is not None
# when it was already checked during setup.
self.data: T = None # type: ignore[assignment]
self.data: _T = None # type: ignore[assignment]
self._listeners: list[CALLBACK_TYPE] = []
self._job = HassJob(self._handle_refresh_interval)
@ -140,7 +140,7 @@ class DataUpdateCoordinator(Generic[T]):
"""
await self._debounced_refresh.async_call()
async def _async_update_data(self) -> T:
async def _async_update_data(self) -> _T:
"""Fetch the latest data from the source."""
if self.update_method is None:
raise NotImplementedError("Update method not implemented")
@ -265,7 +265,7 @@ class DataUpdateCoordinator(Generic[T]):
update_callback()
@callback
def async_set_updated_data(self, data: T) -> None:
def async_set_updated_data(self, data: _T) -> None:
"""Manually update data, notify listeners and reset refresh interval."""
if self._unsub_refresh:
self._unsub_refresh()
@ -295,10 +295,10 @@ class DataUpdateCoordinator(Generic[T]):
self._unsub_refresh = None
class CoordinatorEntity(Generic[T], entity.Entity):
class CoordinatorEntity(Generic[_T], entity.Entity):
"""A class for entities using DataUpdateCoordinator."""
def __init__(self, coordinator: DataUpdateCoordinator[T]) -> None:
def __init__(self, coordinator: DataUpdateCoordinator[_T]) -> None:
"""Create the entity with a DataUpdateCoordinator."""
self.coordinator = coordinator