Make StatesMetaManager thread-safe when an entity_id is fully deleted from the database and then re-added (#89732)

* refactor to make StatesMetaManager threadsafe

* refactor to make StatesMetaManager threadsafe

* refactor to make StatesMetaManager threadsafe

* refactor to make StatesMetaManager threadsafe

* reduce

* comments
pull/89748/head
J. Nick Koston 2023-03-15 02:54:02 -10:00 committed by GitHub
parent 6a01c3369d
commit a244749712
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 175 additions and 42 deletions

View File

@ -155,7 +155,7 @@ class EventProcessor:
if self.entity_ids:
instance = get_instance(self.hass)
entity_id_to_metadata_id = instance.states_meta_manager.get_many(
self.entity_ids, session
self.entity_ids, session, False
)
metadata_ids = [
metadata_id

View File

@ -1027,7 +1027,7 @@ class Recorder(threading.Thread):
states_meta_manager = self.states_meta_manager
if pending_states_meta := states_meta_manager.get_pending(entity_id):
dbstate.states_meta_rel = pending_states_meta
elif metadata_id := states_meta_manager.get(entity_id, event_session):
elif metadata_id := states_meta_manager.get(entity_id, event_session, True):
dbstate.metadata_id = metadata_id
else:
states_meta = StatesMeta(entity_id=entity_id)

View File

@ -242,7 +242,7 @@ def get_significant_states_with_session(
if entity_ids:
instance = recorder.get_instance(hass)
entity_id_to_metadata_id = instance.states_meta_manager.get_many(
entity_ids, session
entity_ids, session, False
)
metadata_ids = [
metadata_id
@ -365,7 +365,7 @@ def state_changes_during_period(
entity_id_to_metadata_id = None
if entity_id:
instance = recorder.get_instance(hass)
metadata_id = instance.states_meta_manager.get(entity_id, session)
metadata_id = instance.states_meta_manager.get(entity_id, session, False)
entity_id_to_metadata_id = {entity_id: metadata_id}
stmt = _state_changed_during_period_stmt(
start_time,
@ -426,7 +426,9 @@ def get_last_state_changes(
with session_scope(hass=hass, read_only=True) as session:
instance = recorder.get_instance(hass)
if not (metadata_id := instance.states_meta_manager.get(entity_id, session)):
if not (
metadata_id := instance.states_meta_manager.get(entity_id, session, False)
):
return {}
entity_id_to_metadata_id: dict[str, int | None] = {entity_id_lower: metadata_id}
stmt = _get_last_state_changes_stmt(number_of_states, metadata_id)

View File

@ -1457,7 +1457,9 @@ def migrate_entity_ids(instance: Recorder) -> bool:
with session_scope(session=instance.get_session()) as session:
if states := session.execute(find_entity_ids_to_migrate()).all():
entity_ids = {entity_id for _, entity_id in states}
entity_id_to_metadata_id = states_meta_manager.get_many(entity_ids, session)
entity_id_to_metadata_id = states_meta_manager.get_many(
entity_ids, session, True
)
if missing_entity_ids := {
# We should never see _EMPTY_ENTITY_ID in the states table
# but we need to be defensive so we don't fail the migration

View File

@ -47,7 +47,11 @@ class EventDataManager(BaseTableManager):
return None
def load(self, events: list[Event], session: Session) -> None:
"""Load the shared_datas to data_ids mapping into memory from events."""
"""Load the shared_datas to data_ids mapping into memory from events.
This call is not thread-safe and must be called from the
recorder thread.
"""
if hashes := {
EventData.hash_shared_data_bytes(shared_event_bytes)
for event in events
@ -56,17 +60,29 @@ class EventDataManager(BaseTableManager):
self._load_from_hashes(hashes, session)
def get(self, shared_data: str, data_hash: int, session: Session) -> int | None:
    """Resolve shared_datas to the data_id.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    # Delegates to get_many with a single-element tuple of (shared_data, hash).
    return self.get_many(((shared_data, data_hash),), session)[shared_data]
def get_from_cache(self, shared_data: str) -> int | None:
    """Resolve shared_data to the data_id without accessing the underlying database.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    # Cache-only lookup; returns None on a miss instead of querying the DB.
    return self._id_map.get(shared_data)
def get_many(
self, shared_data_data_hashs: Iterable[tuple[str, int]], session: Session
) -> dict[str, int | None]:
"""Resolve shared_datas to data_ids."""
"""Resolve shared_datas to data_ids.
This call is not thread-safe and must be called from the
recorder thread.
"""
results: dict[str, int | None] = {}
missing_hashes: set[int] = set()
for shared_data, data_hash in shared_data_data_hashs:
@ -83,7 +99,11 @@ class EventDataManager(BaseTableManager):
def _load_from_hashes(
self, hashes: Iterable[int], session: Session
) -> dict[str, int | None]:
"""Load the shared_datas to data_ids mapping into memory from a list of hashes."""
"""Load the shared_datas to data_ids mapping into memory from a list of hashes.
This call is not thread-safe and must be called from the
recorder thread.
"""
results: dict[str, int | None] = {}
with session.no_autoflush:
for hashs_chunk in chunked(hashes, SQLITE_MAX_BIND_VARS):
@ -97,28 +117,48 @@ class EventDataManager(BaseTableManager):
return results
def get_pending(self, shared_data: str) -> EventData | None:
    """Get pending EventData that have not been assigned ids yet.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    return self._pending.get(shared_data)
def add_pending(self, db_event_data: EventData) -> None:
    """Add a pending EventData that will be committed at the next interval.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    # shared_data is nullable in the ORM model; a pending row must have it set.
    assert db_event_data.shared_data is not None
    shared_data: str = db_event_data.shared_data
    self._pending[shared_data] = db_event_data
def post_commit_pending(self) -> None:
    """Call after commit to load the data_ids of the new EventData into the LRU.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    # data_id is only populated by the database after the commit, so the
    # cache can only be filled in here, not at add_pending time.
    for shared_data, db_event_data in self._pending.items():
        self._id_map[shared_data] = db_event_data.data_id
    self._pending.clear()
def reset(self) -> None:
    """Reset the event manager after the database has been reset or changed.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    self._id_map.clear()
    self._pending.clear()
def evict_purged(self, data_ids: set[int]) -> None:
"""Evict purged data_ids from the cache when they are no longer used."""
"""Evict purged data_ids from the cache when they are no longer used.
This call is not thread-safe and must be called from the
recorder thread.
"""
id_map = self._id_map
event_data_ids_reversed = {
data_id: shared_data for shared_data, data_id in id_map.items()

View File

@ -32,20 +32,32 @@ class EventTypeManager(BaseTableManager):
super().__init__(recorder)
def load(self, events: list[Event], session: Session) -> None:
    """Load the event_type to event_type_ids mapping into memory.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    # get_many populates the cache as a side effect; the result is discarded.
    self.get_many(
        {event.event_type for event in events if event.event_type is not None},
        session,
    )
def get(self, event_type: str, session: Session) -> int | None:
    """Resolve event_type to the event_type_id.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    # Single-item convenience wrapper around get_many.
    return self.get_many((event_type,), session)[event_type]
def get_many(
self, event_types: Iterable[str], session: Session
) -> dict[str, int | None]:
"""Resolve event_types to event_type_ids."""
"""Resolve event_types to event_type_ids.
This call is not thread-safe and must be called from the
recorder thread.
"""
results: dict[str, int | None] = {}
missing: list[str] = []
for event_type in event_types:
@ -69,27 +81,47 @@ class EventTypeManager(BaseTableManager):
return results
def get_pending(self, event_type: str) -> EventTypes | None:
    """Get pending EventTypes that have not been assigned ids yet.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    return self._pending.get(event_type)
def add_pending(self, db_event_type: EventTypes) -> None:
    """Add a pending EventTypes that will be committed at the next interval.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    # event_type is nullable in the ORM model; a pending row must have it set.
    assert db_event_type.event_type is not None
    event_type: str = db_event_type.event_type
    self._pending[event_type] = db_event_type
def post_commit_pending(self) -> None:
    """Call after commit to load the event_type_ids of the new EventTypes into the LRU.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    # event_type_id is assigned by the database at commit time, so the
    # cache can only be populated after the commit completes.
    for event_type, db_event_types in self._pending.items():
        self._id_map[event_type] = db_event_types.event_type_id
    self._pending.clear()
def reset(self) -> None:
    """Reset the event manager after the database has been reset or changed.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    self._id_map.clear()
    self._pending.clear()
def evict_purged(self, event_types: Iterable[str]) -> None:
    """Evict purged event_types from the cache when they are no longer used.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    for event_type in event_types:
        # pop with default: the event_type may already have been evicted by the LRU.
        self._id_map.pop(event_type, None)

View File

@ -28,10 +28,16 @@ class StatesMetaManager(BaseTableManager):
"""Initialize the states meta manager."""
self._id_map: dict[str, int] = LRU(CACHE_SIZE)
self._pending: dict[str, StatesMeta] = {}
self._did_first_load = False
super().__init__(recorder)
def load(self, events: list[Event], session: Session) -> None:
"""Load the entity_id to metadata_id mapping into memory."""
"""Load the entity_id to metadata_id mapping into memory.
This call is not thread-safe and must be called from the
recorder thread.
"""
self._did_first_load = True
self.get_many(
{
event.data["new_state"].entity_id
@ -39,21 +45,41 @@ class StatesMetaManager(BaseTableManager):
if event.data.get("new_state") is not None
},
session,
True,
)
def get(self, entity_id: str, session: Session, from_recorder: bool) -> int | None:
    """Resolve entity_id to the metadata_id.

    This call is not thread-safe after startup since
    purge can remove all references to an entity_id.

    When calling this method from the recorder thread, set
    from_recorder to True to ensure any missing entity_ids
    are added to the cache.
    """
    return self.get_many((entity_id,), session, from_recorder)[entity_id]
def get_metadata_id_to_entity_id(self, session: Session) -> dict[int, str]:
    """Resolve all entity_ids to metadata_ids.

    This call is always thread-safe because it reads straight from the
    database and never touches the in-memory cache.
    """
    with session.no_autoflush:
        return dict(tuple(session.execute(find_all_states_metadata_ids())))  # type: ignore[arg-type]
def get_many(
self, entity_ids: Iterable[str], session: Session
self, entity_ids: Iterable[str], session: Session, from_recorder: bool
) -> dict[str, int | None]:
"""Resolve entity_id to metadata_id."""
"""Resolve entity_id to metadata_id.
This call is not thread-safe after startup since
purge can remove all references to an entity_id.
When calling this method from the recorder thread, set
from_recorder to True to ensure any missing entity_ids
are added to the cache.
"""
results: dict[str, int | None] = {}
missing: list[str] = []
for entity_id in entity_ids:
@ -65,39 +91,69 @@ class StatesMetaManager(BaseTableManager):
if not missing:
return results
# Only update the cache if we are in the recorder thread
# or the recorder event loop has not started yet since
# there is a chance that we could have just deleted all
# instances of an entity_id from the database via purge
# and we do not want to add it back to the cache from another
# thread (history query).
update_cache = from_recorder or not self._did_first_load
with session.no_autoflush:
for missing_chunk in chunked(missing, SQLITE_MAX_BIND_VARS):
for metadata_id, entity_id in session.execute(
find_states_metadata_ids(missing_chunk)
):
results[entity_id] = self._id_map[entity_id] = cast(
int, metadata_id
)
metadata_id = cast(int, metadata_id)
results[entity_id] = metadata_id
if update_cache:
self._id_map[entity_id] = metadata_id
return results
def get_pending(self, entity_id: str) -> StatesMeta | None:
    """Get pending StatesMeta that have not been assigned ids yet.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    return self._pending.get(entity_id)
def add_pending(self, db_states_meta: StatesMeta) -> None:
    """Add a pending StatesMeta that will be committed at the next interval.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    # entity_id is nullable in the ORM model; a pending row must have it set.
    assert db_states_meta.entity_id is not None
    entity_id: str = db_states_meta.entity_id
    self._pending[entity_id] = db_states_meta
def post_commit_pending(self) -> None:
    """Call after commit to load the metadata_ids of the new StatesMeta into the LRU.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    # metadata_id is assigned by the database at commit time, so the
    # cache can only be populated after the commit completes.
    for entity_id, db_states_meta in self._pending.items():
        self._id_map[entity_id] = db_states_meta.metadata_id
    self._pending.clear()
def reset(self) -> None:
    """Reset the states meta manager after the database has been reset or changed.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    self._id_map.clear()
    self._pending.clear()
def evict_purged(self, entity_ids: Iterable[str]) -> None:
    """Evict purged entity_ids from the cache when they are no longer used.

    This call is not thread-safe and must be called from the
    recorder thread.
    """
    for entity_id in entity_ids:
        # pop with default: the entity_id may already have been evicted by the LRU.
        self._id_map.pop(entity_id, None)

View File

@ -59,7 +59,7 @@ ORIG_TZ = dt_util.DEFAULT_TIME_ZONE
def _get_native_states(hass, entity_id):
with session_scope(hass=hass) as session:
instance = recorder.get_instance(hass)
metadata_id = instance.states_meta_manager.get(entity_id, session)
metadata_id = instance.states_meta_manager.get(entity_id, session, True)
states = []
for dbstate in session.query(States).filter(States.metadata_id == metadata_id):
dbstate.entity_id = entity_id

View File

@ -684,7 +684,7 @@ def _convert_pending_states_to_meta(instance: Recorder, session: Session) -> Non
states.add(object)
entity_id_to_metadata_ids = instance.states_meta_manager.get_many(
entity_ids, session
entity_ids, session, True
)
for state in states:
@ -1974,6 +1974,7 @@ async def test_purge_old_states_purges_the_state_metadata_ids(
return instance.states_meta_manager.get_many(
["sensor.one", "sensor.two", "sensor.three", "sensor.unused"],
session,
True,
)
entity_id_to_metadata_id = await instance.async_add_executor_job(_insert_states)

View File

@ -908,7 +908,7 @@ def test_execute_stmt_lambda_element(
with session_scope(hass=hass) as session:
# No time window, we always get a list
metadata_id = instance.states_meta_manager.get("sensor.on", session)
metadata_id = instance.states_meta_manager.get("sensor.on", session, True)
stmt = _get_single_entity_states_stmt(dt_util.utcnow(), metadata_id, False)
rows = util.execute_stmt_lambda_element(session, stmt)
assert isinstance(rows, list)