Fix race in tracking pending writes in recorder (#93414)
parent
41702410f7
commit
63b81d86ef
|
@ -215,6 +215,7 @@ class Recorder(threading.Thread):
|
|||
|
||||
self.schema_version = 0
|
||||
self._commits_without_expire = 0
|
||||
self._event_session_has_pending_writes = False
|
||||
|
||||
self.recorder_runs_manager = RecorderRunsManager()
|
||||
self.states_manager = StatesManager()
|
||||
|
@ -322,7 +323,7 @@ class Recorder(threading.Thread):
|
|||
if (
|
||||
self._event_listener
|
||||
and not self._database_lock_task
|
||||
and self._event_session_has_pending_writes()
|
||||
and self._event_session_has_pending_writes
|
||||
):
|
||||
self.queue_task(COMMIT_TASK)
|
||||
|
||||
|
@ -688,6 +689,11 @@ class Recorder(threading.Thread):
|
|||
# anything goes wrong in the run loop
|
||||
self._shutdown()
|
||||
|
||||
def _add_to_session(self, session: Session, obj: object) -> None:
|
||||
"""Add an object to the session."""
|
||||
self._event_session_has_pending_writes = True
|
||||
session.add(obj)
|
||||
|
||||
def _run(self) -> None:
|
||||
"""Start processing events to save."""
|
||||
self.thread_id = threading.get_ident()
|
||||
|
@ -1016,11 +1022,11 @@ class Recorder(threading.Thread):
|
|||
else:
|
||||
event_types = EventTypes(event_type=event.event_type)
|
||||
event_type_manager.add_pending(event_types)
|
||||
session.add(event_types)
|
||||
self._add_to_session(session, event_types)
|
||||
dbevent.event_type_rel = event_types
|
||||
|
||||
if not event.data:
|
||||
session.add(dbevent)
|
||||
self._add_to_session(session, dbevent)
|
||||
return
|
||||
|
||||
event_data_manager = self.event_data_manager
|
||||
|
@ -1042,10 +1048,10 @@ class Recorder(threading.Thread):
|
|||
# No matching attributes found, save them in the DB
|
||||
dbevent_data = EventData(shared_data=shared_data, hash=hash_)
|
||||
event_data_manager.add_pending(dbevent_data)
|
||||
session.add(dbevent_data)
|
||||
self._add_to_session(session, dbevent_data)
|
||||
dbevent.event_data_rel = dbevent_data
|
||||
|
||||
session.add(dbevent)
|
||||
self._add_to_session(session, dbevent)
|
||||
|
||||
def _process_state_changed_event_into_session(self, event: Event) -> None:
|
||||
"""Process a state_changed event into the session."""
|
||||
|
@ -1090,7 +1096,7 @@ class Recorder(threading.Thread):
|
|||
else:
|
||||
states_meta = StatesMeta(entity_id=entity_id)
|
||||
states_meta_manager.add_pending(states_meta)
|
||||
session.add(states_meta)
|
||||
self._add_to_session(session, states_meta)
|
||||
dbstate.states_meta_rel = states_meta
|
||||
|
||||
# Map the event data to the StateAttributes table
|
||||
|
@ -1115,10 +1121,10 @@ class Recorder(threading.Thread):
|
|||
# No matching attributes found, save them in the DB
|
||||
dbstate_attributes = StateAttributes(shared_attrs=shared_attrs, hash=hash_)
|
||||
state_attributes_manager.add_pending(dbstate_attributes)
|
||||
session.add(dbstate_attributes)
|
||||
self._add_to_session(session, dbstate_attributes)
|
||||
dbstate.state_attributes = dbstate_attributes
|
||||
|
||||
session.add(dbstate)
|
||||
self._add_to_session(session, dbstate)
|
||||
|
||||
def _handle_database_error(self, err: Exception) -> bool:
|
||||
"""Handle a database error that may result in moving away the corrupt db."""
|
||||
|
@ -1130,14 +1136,9 @@ class Recorder(threading.Thread):
|
|||
return True
|
||||
return False
|
||||
|
||||
def _event_session_has_pending_writes(self) -> bool:
|
||||
"""Return True if there are pending writes in the event session."""
|
||||
session = self.event_session
|
||||
return bool(session and (session.new or session.dirty))
|
||||
|
||||
def _commit_event_session_or_retry(self) -> None:
|
||||
"""Commit the event session if there is work to do."""
|
||||
if not self._event_session_has_pending_writes():
|
||||
if not self._event_session_has_pending_writes:
|
||||
return
|
||||
tries = 1
|
||||
while tries <= self.db_max_retries:
|
||||
|
@ -1163,6 +1164,7 @@ class Recorder(threading.Thread):
|
|||
self._commits_without_expire += 1
|
||||
|
||||
session.commit()
|
||||
self._event_session_has_pending_writes = False
|
||||
# We just committed the state attributes to the database
|
||||
# and we now know the attributes_ids. We can save
|
||||
# many selects for matching attributes by loading them
|
||||
|
@ -1263,7 +1265,7 @@ class Recorder(threading.Thread):
|
|||
|
||||
async def async_block_till_done(self) -> None:
|
||||
"""Async version of block_till_done."""
|
||||
if self._queue.empty() and not self._event_session_has_pending_writes():
|
||||
if self._queue.empty() and not self._event_session_has_pending_writes:
|
||||
return
|
||||
event = asyncio.Event()
|
||||
self.queue_task(SynchronizeTask(event))
|
||||
|
@ -1417,6 +1419,8 @@ class Recorder(threading.Thread):
|
|||
if self.event_session is None:
|
||||
return
|
||||
if self.recorder_runs_manager.active:
|
||||
# .end will add to the event session
|
||||
self._event_session_has_pending_writes = True
|
||||
self.recorder_runs_manager.end(self.event_session)
|
||||
try:
|
||||
self._commit_event_session_or_retry()
|
||||
|
|
Loading…
Reference in New Issue