Fix race in tracking pending writes in recorder (#93414)

pull/93422/head
J. Nick Koston 2023-05-23 14:47:31 -05:00 committed by Franck Nijhof
parent 41702410f7
commit 63b81d86ef
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
1 changed files with 19 additions and 15 deletions

View File

@ -215,6 +215,7 @@ class Recorder(threading.Thread):
self.schema_version = 0
self._commits_without_expire = 0
self._event_session_has_pending_writes = False
self.recorder_runs_manager = RecorderRunsManager()
self.states_manager = StatesManager()
@ -322,7 +323,7 @@ class Recorder(threading.Thread):
if (
self._event_listener
and not self._database_lock_task
and self._event_session_has_pending_writes()
and self._event_session_has_pending_writes
):
self.queue_task(COMMIT_TASK)
@ -688,6 +689,11 @@ class Recorder(threading.Thread):
# anything goes wrong in the run loop
self._shutdown()
def _add_to_session(self, session: Session, obj: object) -> None:
"""Add an object to the session."""
self._event_session_has_pending_writes = True
session.add(obj)
def _run(self) -> None:
"""Start processing events to save."""
self.thread_id = threading.get_ident()
@ -1016,11 +1022,11 @@ class Recorder(threading.Thread):
else:
event_types = EventTypes(event_type=event.event_type)
event_type_manager.add_pending(event_types)
session.add(event_types)
self._add_to_session(session, event_types)
dbevent.event_type_rel = event_types
if not event.data:
session.add(dbevent)
self._add_to_session(session, dbevent)
return
event_data_manager = self.event_data_manager
@ -1042,10 +1048,10 @@ class Recorder(threading.Thread):
# No matching attributes found, save them in the DB
dbevent_data = EventData(shared_data=shared_data, hash=hash_)
event_data_manager.add_pending(dbevent_data)
session.add(dbevent_data)
self._add_to_session(session, dbevent_data)
dbevent.event_data_rel = dbevent_data
session.add(dbevent)
self._add_to_session(session, dbevent)
def _process_state_changed_event_into_session(self, event: Event) -> None:
"""Process a state_changed event into the session."""
@ -1090,7 +1096,7 @@ class Recorder(threading.Thread):
else:
states_meta = StatesMeta(entity_id=entity_id)
states_meta_manager.add_pending(states_meta)
session.add(states_meta)
self._add_to_session(session, states_meta)
dbstate.states_meta_rel = states_meta
# Map the event data to the StateAttributes table
@ -1115,10 +1121,10 @@ class Recorder(threading.Thread):
# No matching attributes found, save them in the DB
dbstate_attributes = StateAttributes(shared_attrs=shared_attrs, hash=hash_)
state_attributes_manager.add_pending(dbstate_attributes)
session.add(dbstate_attributes)
self._add_to_session(session, dbstate_attributes)
dbstate.state_attributes = dbstate_attributes
session.add(dbstate)
self._add_to_session(session, dbstate)
def _handle_database_error(self, err: Exception) -> bool:
"""Handle a database error that may result in moving away the corrupt db."""
@ -1130,14 +1136,9 @@ class Recorder(threading.Thread):
return True
return False
def _event_session_has_pending_writes(self) -> bool:
"""Return True if there are pending writes in the event session."""
session = self.event_session
return bool(session and (session.new or session.dirty))
def _commit_event_session_or_retry(self) -> None:
"""Commit the event session if there is work to do."""
if not self._event_session_has_pending_writes():
if not self._event_session_has_pending_writes:
return
tries = 1
while tries <= self.db_max_retries:
@ -1163,6 +1164,7 @@ class Recorder(threading.Thread):
self._commits_without_expire += 1
session.commit()
self._event_session_has_pending_writes = False
# We just committed the state attributes to the database
# and we now know the attributes_ids. We can save
# many selects for matching attributes by loading them
@ -1263,7 +1265,7 @@ class Recorder(threading.Thread):
async def async_block_till_done(self) -> None:
"""Async version of block_till_done."""
if self._queue.empty() and not self._event_session_has_pending_writes():
if self._queue.empty() and not self._event_session_has_pending_writes:
return
event = asyncio.Event()
self.queue_task(SynchronizeTask(event))
@ -1417,6 +1419,8 @@ class Recorder(threading.Thread):
if self.event_session is None:
return
if self.recorder_runs_manager.active:
# .end will add to the event session
self._event_session_has_pending_writes = True
self.recorder_runs_manager.end(self.event_session)
try:
self._commit_event_session_or_retry()