Fix memory leak when firing state_changed events (#72571)
parent
465210784f
commit
049c06061c
|
@ -746,7 +746,7 @@ class LazyState(State):
|
|||
def context(self) -> Context: # type: ignore[override]
|
||||
"""State context."""
|
||||
if self._context is None:
|
||||
self._context = Context(id=None) # type: ignore[arg-type]
|
||||
self._context = Context(id=None)
|
||||
return self._context
|
||||
|
||||
@context.setter
|
||||
|
|
|
@ -37,7 +37,6 @@ from typing import (
|
|||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import attr
|
||||
import voluptuous as vol
|
||||
import yarl
|
||||
|
||||
|
@ -716,14 +715,26 @@ class HomeAssistant:
|
|||
self._stopped.set()
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=False)
|
||||
class Context:
|
||||
"""The context that triggered something."""
|
||||
|
||||
user_id: str | None = attr.ib(default=None)
|
||||
parent_id: str | None = attr.ib(default=None)
|
||||
id: str = attr.ib(factory=ulid_util.ulid)
|
||||
origin_event: Event | None = attr.ib(default=None, eq=False)
|
||||
__slots__ = ("user_id", "parent_id", "id", "origin_event")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
parent_id: str | None = None,
|
||||
id: str | None = None, # pylint: disable=redefined-builtin
|
||||
) -> None:
|
||||
"""Init the context."""
|
||||
self.id = id or ulid_util.ulid()
|
||||
self.user_id = user_id
|
||||
self.parent_id = parent_id
|
||||
self.origin_event: Event | None = None
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Compare contexts."""
|
||||
return bool(self.__class__ == other.__class__ and self.id == other.id)
|
||||
|
||||
def as_dict(self) -> dict[str, str | None]:
|
||||
"""Return a dictionary representation of the context."""
|
||||
|
@ -1163,6 +1174,24 @@ class State:
|
|||
context,
|
||||
)
|
||||
|
||||
def expire(self) -> None:
|
||||
"""Mark the state as old.
|
||||
|
||||
We give up the original reference to the context to ensure
|
||||
the context can be garbage collected by replacing it with
|
||||
a new one with the same id to ensure the old state
|
||||
can still be examined for comparison against the new state.
|
||||
|
||||
Since we are always going to fire a EVENT_STATE_CHANGED event
|
||||
after we remove a state from the state machine we need to make
|
||||
sure we don't end up holding a reference to the original context
|
||||
since it can never be garbage collected as each event would
|
||||
reference the previous one.
|
||||
"""
|
||||
self.context = Context(
|
||||
self.context.user_id, self.context.parent_id, self.context.id
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Return the comparison of the state."""
|
||||
return ( # type: ignore[no-any-return]
|
||||
|
@ -1303,6 +1332,7 @@ class StateMachine:
|
|||
if old_state is None:
|
||||
return False
|
||||
|
||||
old_state.expire()
|
||||
self._bus.async_fire(
|
||||
EVENT_STATE_CHANGED,
|
||||
{"entity_id": entity_id, "old_state": old_state, "new_state": None},
|
||||
|
@ -1396,7 +1426,6 @@ class StateMachine:
|
|||
|
||||
if context is None:
|
||||
context = Context(id=ulid_util.ulid(dt_util.utc_to_timestamp(now)))
|
||||
|
||||
state = State(
|
||||
entity_id,
|
||||
new_state,
|
||||
|
@ -1406,6 +1435,8 @@ class StateMachine:
|
|||
context,
|
||||
old_state is None,
|
||||
)
|
||||
if old_state is not None:
|
||||
old_state.expire()
|
||||
self._states[entity_id] = state
|
||||
self._bus.async_fire(
|
||||
EVENT_STATE_CHANGED,
|
||||
|
|
|
@ -6,9 +6,11 @@ import array
|
|||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
import functools
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
@ -1829,3 +1831,46 @@ async def test_event_context(hass):
|
|||
cancel2()
|
||||
|
||||
assert dummy_event2.context.origin_event == dummy_event
|
||||
|
||||
|
||||
def _get_full_name(obj) -> str:
|
||||
"""Get the full name of an object in memory."""
|
||||
objtype = type(obj)
|
||||
name = objtype.__name__
|
||||
if module := getattr(objtype, "__module__", None):
|
||||
return f"{module}.{name}"
|
||||
return name
|
||||
|
||||
|
||||
def _get_by_type(full_name: str) -> list[Any]:
|
||||
"""Get all objects in memory with a specific type."""
|
||||
return [obj for obj in gc.get_objects() if _get_full_name(obj) == full_name]
|
||||
|
||||
|
||||
# The logger will hold a strong reference to the event for the life of the tests
|
||||
# so we must patch it out
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("DEBUG_MEMORY"),
|
||||
reason="Takes too long on the CI",
|
||||
)
|
||||
@patch.object(ha._LOGGER, "debug", lambda *args: None)
|
||||
async def test_state_changed_events_to_not_leak_contexts(hass):
|
||||
"""Test state changed events do not leak contexts."""
|
||||
gc.collect()
|
||||
# Other tests can log Contexts which keep them in memory
|
||||
# so we need to look at how many exist at the start
|
||||
init_count = len(_get_by_type("homeassistant.core.Context"))
|
||||
|
||||
assert len(_get_by_type("homeassistant.core.Context")) == init_count
|
||||
for i in range(20):
|
||||
hass.states.async_set("light.switch", str(i))
|
||||
await hass.async_block_till_done()
|
||||
gc.collect()
|
||||
|
||||
assert len(_get_by_type("homeassistant.core.Context")) == init_count + 2
|
||||
|
||||
hass.states.async_remove("light.switch")
|
||||
await hass.async_block_till_done()
|
||||
gc.collect()
|
||||
|
||||
assert len(_get_by_type("homeassistant.core.Context")) == init_count
|
||||
|
|
Loading…
Reference in New Issue