Fix memory leak when firing state_changed events (#72571)

pull/72573/head
J. Nick Koston 2022-05-26 17:54:26 -10:00 committed by GitHub
parent 465210784f
commit 049c06061c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 8 deletions

View File

@ -746,7 +746,7 @@ class LazyState(State):
def context(self) -> Context: # type: ignore[override]
"""State context."""
if self._context is None:
self._context = Context(id=None) # type: ignore[arg-type]
self._context = Context(id=None)
return self._context
@context.setter

View File

@ -37,7 +37,6 @@ from typing import (
)
from urllib.parse import urlparse
import attr
import voluptuous as vol
import yarl
@ -716,14 +715,26 @@ class HomeAssistant:
self._stopped.set()
@attr.s(slots=True, frozen=False)
class Context:
"""The context that triggered something."""
user_id: str | None = attr.ib(default=None)
parent_id: str | None = attr.ib(default=None)
id: str = attr.ib(factory=ulid_util.ulid)
origin_event: Event | None = attr.ib(default=None, eq=False)
__slots__ = ("user_id", "parent_id", "id", "origin_event")
def __init__(
self,
user_id: str | None = None,
parent_id: str | None = None,
id: str | None = None, # pylint: disable=redefined-builtin
) -> None:
"""Init the context."""
self.id = id or ulid_util.ulid()
self.user_id = user_id
self.parent_id = parent_id
self.origin_event: Event | None = None
def __eq__(self, other: Any) -> bool:
"""Compare contexts."""
return bool(self.__class__ == other.__class__ and self.id == other.id)
def as_dict(self) -> dict[str, str | None]:
"""Return a dictionary representation of the context."""
@ -1163,6 +1174,24 @@ class State:
context,
)
def expire(self) -> None:
"""Mark the state as old.
We give up the original reference to the context to ensure
the context can be garbage collected by replacing it with
a new one with the same id to ensure the old state
can still be examined for comparison against the new state.
Since we are always going to fire a EVENT_STATE_CHANGED event
after we remove a state from the state machine we need to make
sure we don't end up holding a reference to the original context
since it can never be garbage collected as each event would
reference the previous one.
"""
self.context = Context(
self.context.user_id, self.context.parent_id, self.context.id
)
def __eq__(self, other: Any) -> bool:
"""Return the comparison of the state."""
return ( # type: ignore[no-any-return]
@ -1303,6 +1332,7 @@ class StateMachine:
if old_state is None:
return False
old_state.expire()
self._bus.async_fire(
EVENT_STATE_CHANGED,
{"entity_id": entity_id, "old_state": old_state, "new_state": None},
@ -1396,7 +1426,6 @@ class StateMachine:
if context is None:
context = Context(id=ulid_util.ulid(dt_util.utc_to_timestamp(now)))
state = State(
entity_id,
new_state,
@ -1406,6 +1435,8 @@ class StateMachine:
context,
old_state is None,
)
if old_state is not None:
old_state.expire()
self._states[entity_id] = state
self._bus.async_fire(
EVENT_STATE_CHANGED,

View File

@ -6,9 +6,11 @@ import array
import asyncio
from datetime import datetime, timedelta
import functools
import gc
import logging
import os
from tempfile import TemporaryDirectory
from typing import Any
from unittest.mock import MagicMock, Mock, PropertyMock, patch
import pytest
@ -1829,3 +1831,46 @@ async def test_event_context(hass):
cancel2()
assert dummy_event2.context.origin_event == dummy_event
def _get_full_name(obj) -> str:
"""Get the full name of an object in memory."""
objtype = type(obj)
name = objtype.__name__
if module := getattr(objtype, "__module__", None):
return f"{module}.{name}"
return name
def _get_by_type(full_name: str) -> list[Any]:
"""Get all objects in memory with a specific type."""
return [obj for obj in gc.get_objects() if _get_full_name(obj) == full_name]
# The logger will hold a strong reference to the event for the life of the tests
# so we must patch it out
@pytest.mark.skipif(
not os.environ.get("DEBUG_MEMORY"),
reason="Takes too long on the CI",
)
@patch.object(ha._LOGGER, "debug", lambda *args: None)
async def test_state_changed_events_to_not_leak_contexts(hass):
"""Test state changed events do not leak contexts."""
gc.collect()
# Other tests can log Contexts which keep them in memory
# so we need to look at how many exist at the start
init_count = len(_get_by_type("homeassistant.core.Context"))
assert len(_get_by_type("homeassistant.core.Context")) == init_count
for i in range(20):
hass.states.async_set("light.switch", str(i))
await hass.async_block_till_done()
gc.collect()
assert len(_get_by_type("homeassistant.core.Context")) == init_count + 2
hass.states.async_remove("light.switch")
await hass.async_block_till_done()
gc.collect()
assert len(_get_by_type("homeassistant.core.Context")) == init_count