Use assignment expressions 03 (#57710)
parent
2a8eaf0e0f
commit
238b488642
|
@ -192,8 +192,7 @@ def _async_register_clientsession_shutdown(
|
|||
EVENT_HOMEASSISTANT_CLOSE, _async_close_websession
|
||||
)
|
||||
|
||||
config_entry = config_entries.current_entry.get()
|
||||
if not config_entry:
|
||||
if not (config_entry := config_entries.current_entry.get()):
|
||||
return
|
||||
|
||||
config_entry.async_on_unload(unsub)
|
||||
|
|
|
@ -328,9 +328,8 @@ def async_numeric_state( # noqa: C901
|
|||
|
||||
if isinstance(entity, str):
|
||||
entity_id = entity
|
||||
entity = hass.states.get(entity)
|
||||
|
||||
if entity is None:
|
||||
if (entity := hass.states.get(entity)) is None:
|
||||
raise ConditionErrorMessage("numeric_state", f"unknown entity {entity_id}")
|
||||
else:
|
||||
entity_id = entity.entity_id
|
||||
|
@ -371,8 +370,7 @@ def async_numeric_state( # noqa: C901
|
|||
|
||||
if below is not None:
|
||||
if isinstance(below, str):
|
||||
below_entity = hass.states.get(below)
|
||||
if not below_entity:
|
||||
if not (below_entity := hass.states.get(below)):
|
||||
raise ConditionErrorMessage(
|
||||
"numeric_state", f"unknown 'below' entity {below}"
|
||||
)
|
||||
|
@ -400,8 +398,7 @@ def async_numeric_state( # noqa: C901
|
|||
|
||||
if above is not None:
|
||||
if isinstance(above, str):
|
||||
above_entity = hass.states.get(above)
|
||||
if not above_entity:
|
||||
if not (above_entity := hass.states.get(above)):
|
||||
raise ConditionErrorMessage(
|
||||
"numeric_state", f"unknown 'above' entity {above}"
|
||||
)
|
||||
|
@ -497,9 +494,8 @@ def state(
|
|||
|
||||
if isinstance(entity, str):
|
||||
entity_id = entity
|
||||
entity = hass.states.get(entity)
|
||||
|
||||
if entity is None:
|
||||
if (entity := hass.states.get(entity)) is None:
|
||||
raise ConditionErrorMessage("state", f"unknown entity {entity_id}")
|
||||
else:
|
||||
entity_id = entity.entity_id
|
||||
|
@ -526,8 +522,7 @@ def state(
|
|||
isinstance(req_state_value, str)
|
||||
and INPUT_ENTITY_ID.match(req_state_value) is not None
|
||||
):
|
||||
state_entity = hass.states.get(req_state_value)
|
||||
if not state_entity:
|
||||
if not (state_entity := hass.states.get(req_state_value)):
|
||||
raise ConditionErrorMessage(
|
||||
"state", f"the 'state' entity {req_state_value} is unavailable"
|
||||
)
|
||||
|
@ -738,8 +733,7 @@ def time(
|
|||
if after is None:
|
||||
after = dt_util.dt.time(0)
|
||||
elif isinstance(after, str):
|
||||
after_entity = hass.states.get(after)
|
||||
if not after_entity:
|
||||
if not (after_entity := hass.states.get(after)):
|
||||
raise ConditionErrorMessage("time", f"unknown 'after' entity {after}")
|
||||
if after_entity.domain == "input_datetime":
|
||||
after = dt_util.dt.time(
|
||||
|
@ -763,8 +757,7 @@ def time(
|
|||
if before is None:
|
||||
before = dt_util.dt.time(23, 59, 59, 999999)
|
||||
elif isinstance(before, str):
|
||||
before_entity = hass.states.get(before)
|
||||
if not before_entity:
|
||||
if not (before_entity := hass.states.get(before)):
|
||||
raise ConditionErrorMessage("time", f"unknown 'before' entity {before}")
|
||||
if before_entity.domain == "input_datetime":
|
||||
before = dt_util.dt.time(
|
||||
|
@ -840,9 +833,8 @@ def zone(
|
|||
|
||||
if isinstance(zone_ent, str):
|
||||
zone_ent_id = zone_ent
|
||||
zone_ent = hass.states.get(zone_ent)
|
||||
|
||||
if zone_ent is None:
|
||||
if (zone_ent := hass.states.get(zone_ent)) is None:
|
||||
raise ConditionErrorMessage("zone", f"unknown zone {zone_ent_id}")
|
||||
|
||||
if entity is None:
|
||||
|
@ -850,9 +842,8 @@ def zone(
|
|||
|
||||
if isinstance(entity, str):
|
||||
entity_id = entity
|
||||
entity = hass.states.get(entity)
|
||||
|
||||
if entity is None:
|
||||
if (entity := hass.states.get(entity)) is None:
|
||||
raise ConditionErrorMessage("zone", f"unknown entity {entity_id}")
|
||||
else:
|
||||
entity_id = entity.entity_id
|
||||
|
@ -1029,9 +1020,7 @@ def async_extract_devices(config: ConfigType | Template) -> set[str]:
|
|||
if condition != "device":
|
||||
continue
|
||||
|
||||
device_id = config.get(CONF_DEVICE_ID)
|
||||
|
||||
if device_id is not None:
|
||||
if (device_id := config.get(CONF_DEVICE_ID)) is not None:
|
||||
referenced.add(device_id)
|
||||
|
||||
return referenced
|
||||
|
|
|
@ -129,14 +129,10 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
|||
@property
|
||||
def redirect_uri(self) -> str:
|
||||
"""Return the redirect uri."""
|
||||
req = http.current_request.get()
|
||||
|
||||
if req is None:
|
||||
if (req := http.current_request.get()) is None:
|
||||
raise RuntimeError("No current request in context")
|
||||
|
||||
ha_host = req.headers.get(HEADER_FRONTEND_BASE)
|
||||
|
||||
if ha_host is None:
|
||||
if (ha_host := req.headers.get(HEADER_FRONTEND_BASE)) is None:
|
||||
raise RuntimeError("No header in request")
|
||||
|
||||
return f"{ha_host}{AUTH_CALLBACK_PATH}"
|
||||
|
@ -501,9 +497,7 @@ async def async_oauth2_request(
|
|||
@callback
|
||||
def _encode_jwt(hass: HomeAssistant, data: dict) -> str:
|
||||
"""JWT encode data."""
|
||||
secret = hass.data.get(DATA_JWT_SECRET)
|
||||
|
||||
if secret is None:
|
||||
if (secret := hass.data.get(DATA_JWT_SECRET)) is None:
|
||||
secret = hass.data[DATA_JWT_SECRET] = secrets.token_hex()
|
||||
|
||||
return jwt.encode(data, secret, algorithm="HS256")
|
||||
|
|
|
@ -38,8 +38,7 @@ class _BaseFlowManagerView(HomeAssistantView):
|
|||
|
||||
data = result.copy()
|
||||
|
||||
schema = data["data_schema"]
|
||||
if schema is None:
|
||||
if (schema := data["data_schema"]) is None:
|
||||
data["data_schema"] = []
|
||||
else:
|
||||
data["data_schema"] = voluptuous_serialize.convert(
|
||||
|
|
|
@ -111,9 +111,7 @@ def async_listen_platform(
|
|||
|
||||
async def discovery_platform_listener(discovered: DiscoveryDict) -> None:
|
||||
"""Listen for platform discovery events."""
|
||||
platform = discovered["platform"]
|
||||
|
||||
if not platform:
|
||||
if not (platform := discovered["platform"]):
|
||||
return
|
||||
|
||||
task = hass.async_run_hass_job(job, platform, discovered.get("discovered"))
|
||||
|
|
|
@ -727,8 +727,7 @@ current_platform: ContextVar[EntityPlatform | None] = ContextVar(
|
|||
@callback
|
||||
def async_get_current_platform() -> EntityPlatform:
|
||||
"""Get the current platform from context."""
|
||||
platform = current_platform.get()
|
||||
if platform is None:
|
||||
if (platform := current_platform.get()) is None:
|
||||
raise RuntimeError("Cannot get non-set current platform")
|
||||
return platform
|
||||
|
||||
|
|
|
@ -33,8 +33,7 @@ SPEECH_TYPE_SSML = "ssml"
|
|||
@bind_hass
|
||||
def async_register(hass: HomeAssistant, handler: IntentHandler) -> None:
|
||||
"""Register an intent with Home Assistant."""
|
||||
intents = hass.data.get(DATA_KEY)
|
||||
if intents is None:
|
||||
if (intents := hass.data.get(DATA_KEY)) is None:
|
||||
intents = hass.data[DATA_KEY] = {}
|
||||
|
||||
assert handler.intent_type is not None, "intent_type cannot be None"
|
||||
|
|
|
@ -51,9 +51,7 @@ def find_coordinates(
|
|||
hass: HomeAssistant, entity_id: str, recursion_history: list | None = None
|
||||
) -> str | None:
|
||||
"""Find the gps coordinates of the entity in the form of '90.000,180.000'."""
|
||||
entity_state = hass.states.get(entity_id)
|
||||
|
||||
if entity_state is None:
|
||||
if (entity_state := hass.states.get(entity_id)) is None:
|
||||
_LOGGER.error("Unable to find entity %s", entity_id)
|
||||
return None
|
||||
|
||||
|
|
|
@ -118,8 +118,7 @@ def get_url(
|
|||
|
||||
def _get_request_host() -> str | None:
|
||||
"""Get the host address of the current request."""
|
||||
request = http.current_request.get()
|
||||
if request is None:
|
||||
if (request := http.current_request.get()) is None:
|
||||
raise NoURLAvailableError
|
||||
return yarl.URL(request.url).host
|
||||
|
||||
|
|
|
@ -78,8 +78,7 @@ class KeyedRateLimit:
|
|||
if rate_limit is None:
|
||||
return None
|
||||
|
||||
last_triggered = self._last_triggered.get(key)
|
||||
if not last_triggered:
|
||||
if not (last_triggered := self._last_triggered.get(key)):
|
||||
return None
|
||||
|
||||
next_call_time = last_triggered + rate_limit
|
||||
|
|
|
@ -953,8 +953,7 @@ class Script:
|
|||
variables: ScriptVariables | None = None,
|
||||
) -> None:
|
||||
"""Initialize the script."""
|
||||
all_scripts = hass.data.get(DATA_SCRIPTS)
|
||||
if not all_scripts:
|
||||
if not (all_scripts := hass.data.get(DATA_SCRIPTS)):
|
||||
all_scripts = hass.data[DATA_SCRIPTS] = []
|
||||
hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass)
|
||||
|
@ -1273,8 +1272,7 @@ class Script:
|
|||
config_cache_key = config.template
|
||||
else:
|
||||
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
|
||||
cond = self._config_cache.get(config_cache_key)
|
||||
if not cond:
|
||||
if not (cond := self._config_cache.get(config_cache_key)):
|
||||
cond = await condition.async_from_config(self._hass, config, False)
|
||||
self._config_cache[config_cache_key] = cond
|
||||
return cond
|
||||
|
@ -1297,8 +1295,7 @@ class Script:
|
|||
return sub_script
|
||||
|
||||
def _get_repeat_script(self, step: int) -> Script:
|
||||
sub_script = self._repeat_script.get(step)
|
||||
if not sub_script:
|
||||
if not (sub_script := self._repeat_script.get(step)):
|
||||
sub_script = self._prep_repeat_script(step)
|
||||
self._repeat_script[step] = sub_script
|
||||
return sub_script
|
||||
|
@ -1351,8 +1348,7 @@ class Script:
|
|||
return {"choices": choices, "default": default_script}
|
||||
|
||||
async def _async_get_choose_data(self, step: int) -> _ChooseData:
|
||||
choose_data = self._choose_data.get(step)
|
||||
if not choose_data:
|
||||
if not (choose_data := self._choose_data.get(step)):
|
||||
choose_data = await self._async_prep_choose_data(step)
|
||||
self._choose_data[step] = choose_data
|
||||
return choose_data
|
||||
|
|
|
@ -22,9 +22,7 @@ def validate_selector(config: Any) -> dict:
|
|||
|
||||
selector_type = list(config)[0]
|
||||
|
||||
selector_class = SELECTORS.get(selector_type)
|
||||
|
||||
if selector_class is None:
|
||||
if (selector_class := SELECTORS.get(selector_type)) is None:
|
||||
raise vol.Invalid(f"Unknown selector type {selector_type} found")
|
||||
|
||||
# Selectors can be empty
|
||||
|
|
|
@ -396,10 +396,11 @@ async def async_extract_config_entry_ids(
|
|||
|
||||
# Some devices may have no entities
|
||||
for device_id in referenced.referenced_devices:
|
||||
if device_id in dev_reg.devices:
|
||||
device = dev_reg.async_get(device_id)
|
||||
if device is not None:
|
||||
config_entry_ids.update(device.config_entries)
|
||||
if (
|
||||
device_id in dev_reg.devices
|
||||
and (device := dev_reg.async_get(device_id)) is not None
|
||||
):
|
||||
config_entry_ids.update(device.config_entries)
|
||||
|
||||
for entity_id in referenced.referenced | referenced.indirectly_referenced:
|
||||
entry = ent_reg.async_get(entity_id)
|
||||
|
|
|
@ -813,8 +813,7 @@ class TemplateState(State):
|
|||
|
||||
|
||||
def _collect_state(hass: HomeAssistant, entity_id: str) -> None:
|
||||
entity_collect = hass.data.get(_RENDER_INFO)
|
||||
if entity_collect is not None:
|
||||
if (entity_collect := hass.data.get(_RENDER_INFO)) is not None:
|
||||
entity_collect.entities.add(entity_id)
|
||||
|
||||
|
||||
|
@ -1188,8 +1187,7 @@ def state_attr(hass: HomeAssistant, entity_id: str, name: str) -> Any:
|
|||
|
||||
def now(hass: HomeAssistant) -> datetime:
|
||||
"""Record fetching now."""
|
||||
render_info = hass.data.get(_RENDER_INFO)
|
||||
if render_info is not None:
|
||||
if (render_info := hass.data.get(_RENDER_INFO)) is not None:
|
||||
render_info.has_time = True
|
||||
|
||||
return dt_util.now()
|
||||
|
@ -1197,8 +1195,7 @@ def now(hass: HomeAssistant) -> datetime:
|
|||
|
||||
def utcnow(hass: HomeAssistant) -> datetime:
|
||||
"""Record fetching utcnow."""
|
||||
render_info = hass.data.get(_RENDER_INFO)
|
||||
if render_info is not None:
|
||||
if (render_info := hass.data.get(_RENDER_INFO)) is not None:
|
||||
render_info.has_time = True
|
||||
|
||||
return dt_util.utcnow()
|
||||
|
@ -1843,9 +1840,7 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
|
|||
# any instance of this.
|
||||
return super().compile(source, name, filename, raw, defer_init)
|
||||
|
||||
cached = self.template_cache.get(source)
|
||||
|
||||
if cached is None:
|
||||
if (cached := self.template_cache.get(source)) is None:
|
||||
cached = self.template_cache[source] = super().compile(source)
|
||||
|
||||
return cached
|
||||
|
|
|
@ -113,8 +113,7 @@ def trace_id_get() -> tuple[tuple[str, str], str] | None:
|
|||
|
||||
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
|
||||
"""Push an element to the top of a trace stack."""
|
||||
trace_stack = trace_stack_var.get()
|
||||
if trace_stack is None:
|
||||
if (trace_stack := trace_stack_var.get()) is None:
|
||||
trace_stack = []
|
||||
trace_stack_var.set(trace_stack)
|
||||
trace_stack.append(node)
|
||||
|
@ -149,8 +148,7 @@ def trace_path_pop(count: int) -> None:
|
|||
|
||||
def trace_path_get() -> str:
|
||||
"""Return a string representing the current location in the config tree."""
|
||||
path = trace_path_stack_cv.get()
|
||||
if not path:
|
||||
if not (path := trace_path_stack_cv.get()):
|
||||
return ""
|
||||
return "/".join(path)
|
||||
|
||||
|
@ -160,12 +158,10 @@ def trace_append_element(
|
|||
maxlen: int | None = None,
|
||||
) -> None:
|
||||
"""Append a TraceElement to trace[path]."""
|
||||
path = trace_element.path
|
||||
trace = trace_cv.get()
|
||||
if trace is None:
|
||||
if (trace := trace_cv.get()) is None:
|
||||
trace = {}
|
||||
trace_cv.set(trace)
|
||||
if path not in trace:
|
||||
if (path := trace_element.path) not in trace:
|
||||
trace[path] = deque(maxlen=maxlen)
|
||||
trace[path].append(trace_element)
|
||||
|
||||
|
@ -213,16 +209,14 @@ class StopReason:
|
|||
|
||||
def script_execution_set(reason: str) -> None:
|
||||
"""Set stop reason."""
|
||||
data = script_execution_cv.get()
|
||||
if data is None:
|
||||
if (data := script_execution_cv.get()) is None:
|
||||
return
|
||||
data.script_execution = reason
|
||||
|
||||
|
||||
def script_execution_get() -> str | None:
|
||||
"""Return the current trace."""
|
||||
data = script_execution_cv.get()
|
||||
if data is None:
|
||||
if (data := script_execution_cv.get()) is None:
|
||||
return None
|
||||
return data.script_execution
|
||||
|
||||
|
|
|
@ -146,9 +146,7 @@ async def async_get_custom_components(
|
|||
hass: HomeAssistant,
|
||||
) -> dict[str, Integration]:
|
||||
"""Return cached list of custom integrations."""
|
||||
reg_or_evt = hass.data.get(DATA_CUSTOM_COMPONENTS)
|
||||
|
||||
if reg_or_evt is None:
|
||||
if (reg_or_evt := hass.data.get(DATA_CUSTOM_COMPONENTS)) is None:
|
||||
evt = hass.data[DATA_CUSTOM_COMPONENTS] = asyncio.Event()
|
||||
|
||||
reg = await _async_get_custom_components(hass)
|
||||
|
@ -543,8 +541,7 @@ class Integration:
|
|||
|
||||
async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration:
|
||||
"""Get an integration."""
|
||||
cache = hass.data.get(DATA_INTEGRATIONS)
|
||||
if cache is None:
|
||||
if (cache := hass.data.get(DATA_INTEGRATIONS)) is None:
|
||||
if not _async_mount_config_dir(hass):
|
||||
raise IntegrationNotFound(domain)
|
||||
cache = hass.data[DATA_INTEGRATIONS] = {}
|
||||
|
@ -553,12 +550,11 @@ async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration
|
|||
|
||||
if isinstance(int_or_evt, asyncio.Event):
|
||||
await int_or_evt.wait()
|
||||
int_or_evt = cache.get(domain, _UNDEF)
|
||||
|
||||
# When we have waited and it's _UNDEF, it doesn't exist
|
||||
# We don't cache that it doesn't exist, or else people can't fix it
|
||||
# and then restart, because their config will never be valid.
|
||||
if int_or_evt is _UNDEF:
|
||||
if (int_or_evt := cache.get(domain, _UNDEF)) is _UNDEF:
|
||||
raise IntegrationNotFound(domain)
|
||||
|
||||
if int_or_evt is not _UNDEF:
|
||||
|
@ -630,8 +626,7 @@ def _load_file(
|
|||
with suppress(KeyError):
|
||||
return hass.data[DATA_COMPONENTS][comp_or_platform] # type: ignore
|
||||
|
||||
cache = hass.data.get(DATA_COMPONENTS)
|
||||
if cache is None:
|
||||
if (cache := hass.data.get(DATA_COMPONENTS)) is None:
|
||||
if not _async_mount_config_dir(hass):
|
||||
return None
|
||||
cache = hass.data[DATA_COMPONENTS] = {}
|
||||
|
|
|
@ -60,8 +60,7 @@ async def async_get_integration_with_requirements(
|
|||
if hass.config.skip_pip:
|
||||
return integration
|
||||
|
||||
cache = hass.data.get(DATA_INTEGRATIONS_WITH_REQS)
|
||||
if cache is None:
|
||||
if (cache := hass.data.get(DATA_INTEGRATIONS_WITH_REQS)) is None:
|
||||
cache = hass.data[DATA_INTEGRATIONS_WITH_REQS] = {}
|
||||
|
||||
int_or_evt: Integration | asyncio.Event | None | UndefinedType = cache.get(
|
||||
|
@ -71,12 +70,10 @@ async def async_get_integration_with_requirements(
|
|||
if isinstance(int_or_evt, asyncio.Event):
|
||||
await int_or_evt.wait()
|
||||
|
||||
int_or_evt = cache.get(domain, UNDEFINED)
|
||||
|
||||
# When we have waited and it's UNDEFINED, it doesn't exist
|
||||
# We don't cache that it doesn't exist, or else people can't fix it
|
||||
# and then restart, because their config will never be valid.
|
||||
if int_or_evt is UNDEFINED:
|
||||
if (int_or_evt := cache.get(domain, UNDEFINED)) is UNDEFINED:
|
||||
raise IntegrationNotFound(domain)
|
||||
|
||||
if int_or_evt is not UNDEFINED:
|
||||
|
@ -154,8 +151,7 @@ async def async_process_requirements(
|
|||
This method is a coroutine. It will raise RequirementsNotFound
|
||||
if an requirement can't be satisfied.
|
||||
"""
|
||||
pip_lock = hass.data.get(DATA_PIP_LOCK)
|
||||
if pip_lock is None:
|
||||
if (pip_lock := hass.data.get(DATA_PIP_LOCK)) is None:
|
||||
pip_lock = hass.data[DATA_PIP_LOCK] = asyncio.Lock()
|
||||
install_failure_history = hass.data.get(DATA_INSTALL_FAILURE_HISTORY)
|
||||
if install_failure_history is None:
|
||||
|
|
|
@ -83,8 +83,7 @@ class HassEventLoopPolicy(asyncio.DefaultEventLoopPolicy): # type: ignore[valid
|
|||
def _async_loop_exception_handler(_: Any, context: dict[str, Any]) -> None:
|
||||
"""Handle all exception inside the core loop."""
|
||||
kwargs = {}
|
||||
exception = context.get("exception")
|
||||
if exception:
|
||||
if exception := context.get("exception"):
|
||||
kwargs["exc_info"] = (type(exception), exception, exception.__traceback__)
|
||||
|
||||
logging.getLogger(__package__).error(
|
||||
|
|
Loading…
Reference in New Issue