Move thread safety in label_registry sooner (#117026)
parent
9557ea902e
commit
fc3c384e0a
|
@ -121,6 +121,7 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
|
|||
description: str | None = None,
|
||||
) -> LabelEntry:
|
||||
"""Create a new label."""
|
||||
self.hass.verify_event_loop_thread("async_create")
|
||||
if label := self.async_get_label_by_name(name):
|
||||
raise ValueError(
|
||||
f"The name {name} ({label.normalized_name}) is already in use"
|
||||
|
@ -139,7 +140,7 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
|
|||
label_id = label.label_id
|
||||
self.labels[label_id] = label
|
||||
self.async_schedule_save()
|
||||
self.hass.bus.async_fire(
|
||||
self.hass.bus.async_fire_internal(
|
||||
EVENT_LABEL_REGISTRY_UPDATED,
|
||||
EventLabelRegistryUpdatedData(
|
||||
action="create",
|
||||
|
@ -151,8 +152,9 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
|
|||
@callback
|
||||
def async_delete(self, label_id: str) -> None:
|
||||
"""Delete label."""
|
||||
self.hass.verify_event_loop_thread("async_delete")
|
||||
del self.labels[label_id]
|
||||
self.hass.bus.async_fire(
|
||||
self.hass.bus.async_fire_internal(
|
||||
EVENT_LABEL_REGISTRY_UPDATED,
|
||||
EventLabelRegistryUpdatedData(
|
||||
action="remove",
|
||||
|
@ -190,10 +192,11 @@ class LabelRegistry(BaseRegistry[LabelRegistryStoreData]):
|
|||
if not changes:
|
||||
return old
|
||||
|
||||
self.hass.verify_event_loop_thread("async_update")
|
||||
new = self.labels[label_id] = dataclasses.replace(old, **changes) # type: ignore[arg-type]
|
||||
|
||||
self.async_schedule_save()
|
||||
self.hass.bus.async_fire(
|
||||
self.hass.bus.async_fire_internal(
|
||||
EVENT_LABEL_REGISTRY_UPDATED,
|
||||
EventLabelRegistryUpdatedData(
|
||||
action="update",
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Tests for the Label Registry."""
|
||||
|
||||
from functools import partial
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
@ -454,3 +455,45 @@ async def test_labels_removed_from_entities(
|
|||
assert len(entries) == 0
|
||||
entries = er.async_entries_for_label(entity_registry, label2.label_id)
|
||||
assert len(entries) == 0
|
||||
|
||||
|
||||
async def test_async_create_thread_safety(
|
||||
hass: HomeAssistant,
|
||||
label_registry: lr.LabelRegistry,
|
||||
) -> None:
|
||||
"""Test async_create raises when called from wrong thread."""
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Detected code that calls async_create from a thread. Please report this issue.",
|
||||
):
|
||||
await hass.async_add_executor_job(label_registry.async_create, "any")
|
||||
|
||||
|
||||
async def test_async_delete_thread_safety(
|
||||
hass: HomeAssistant,
|
||||
label_registry: lr.LabelRegistry,
|
||||
) -> None:
|
||||
"""Test async_delete raises when called from wrong thread."""
|
||||
any_label = label_registry.async_create("any")
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Detected code that calls async_delete from a thread. Please report this issue.",
|
||||
):
|
||||
await hass.async_add_executor_job(label_registry.async_delete, any_label)
|
||||
|
||||
|
||||
async def test_async_update_thread_safety(
|
||||
hass: HomeAssistant,
|
||||
label_registry: lr.LabelRegistry,
|
||||
) -> None:
|
||||
"""Test async_update raises when called from wrong thread."""
|
||||
any_label = label_registry.async_create("any")
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Detected code that calls async_update from a thread. Please report this issue.",
|
||||
):
|
||||
await hass.async_add_executor_job(
|
||||
partial(label_registry.async_update, any_label.label_id, name="new name")
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue