Small cleanup to logbook context augmenter (#72043)

pull/72052/head
J. Nick Koston 2022-05-17 23:10:28 -05:00 committed by GitHub
parent 1d6659224f
commit ec01e00184
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 9 deletions

View File

@ -393,7 +393,11 @@ def _humanify(
context_lookup: dict[str | None, Row | None] = {None: None} context_lookup: dict[str | None, Row | None] = {None: None}
event_cache = EventCache(event_data_cache) event_cache = EventCache(event_data_cache)
context_augmenter = ContextAugmenter( context_augmenter = ContextAugmenter(
context_lookup, entity_name_cache, external_events, event_cache context_lookup,
entity_name_cache,
external_events,
event_cache,
include_entity_name,
) )
def _keep_row(row: Row, event_type: str) -> bool: def _keep_row(row: Row, event_type: str) -> bool:
@ -447,7 +451,7 @@ def _humanify(
if icon := row.icon or row.old_format_icon: if icon := row.icon or row.old_format_icon:
data[LOGBOOK_ENTRY_ICON] = icon data[LOGBOOK_ENTRY_ICON] = icon
context_augmenter.augment(data, row, context_id, include_entity_name) context_augmenter.augment(data, row, context_id)
yield data yield data
elif event_type in external_events: elif event_type in external_events:
@ -455,7 +459,7 @@ def _humanify(
data = describe_event(event_cache.get(row)) data = describe_event(event_cache.get(row))
data[LOGBOOK_ENTRY_WHEN] = format_time(row) data[LOGBOOK_ENTRY_WHEN] = format_time(row)
data[LOGBOOK_ENTRY_DOMAIN] = domain data[LOGBOOK_ENTRY_DOMAIN] = domain
context_augmenter.augment(data, row, context_id, include_entity_name) context_augmenter.augment(data, row, context_id)
yield data yield data
elif event_type == EVENT_LOGBOOK_ENTRY: elif event_type == EVENT_LOGBOOK_ENTRY:
@ -475,7 +479,7 @@ def _humanify(
LOGBOOK_ENTRY_DOMAIN: entry_domain, LOGBOOK_ENTRY_DOMAIN: entry_domain,
LOGBOOK_ENTRY_ENTITY_ID: entry_entity_id, LOGBOOK_ENTRY_ENTITY_ID: entry_entity_id,
} }
context_augmenter.augment(data, row, context_id, include_entity_name) context_augmenter.augment(data, row, context_id)
yield data yield data
@ -558,16 +562,16 @@ class ContextAugmenter:
str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]] str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]]
], ],
event_cache: EventCache, event_cache: EventCache,
include_entity_name: bool,
) -> None: ) -> None:
"""Init the augmenter.""" """Init the augmenter."""
self.context_lookup = context_lookup self.context_lookup = context_lookup
self.entity_name_cache = entity_name_cache self.entity_name_cache = entity_name_cache
self.external_events = external_events self.external_events = external_events
self.event_cache = event_cache self.event_cache = event_cache
self.include_entity_name = include_entity_name
def augment( def augment(self, data: dict[str, Any], row: Row, context_id: str) -> None:
self, data: dict[str, Any], row: Row, context_id: str, include_entity_name: bool
) -> None:
"""Augment data from the row and cache.""" """Augment data from the row and cache."""
if context_user_id := row.context_user_id: if context_user_id := row.context_user_id:
data[CONTEXT_USER_ID] = context_user_id data[CONTEXT_USER_ID] = context_user_id
@ -594,7 +598,7 @@ class ContextAugmenter:
# State change # State change
if context_entity_id := context_row.entity_id: if context_entity_id := context_row.entity_id:
data[CONTEXT_ENTITY_ID] = context_entity_id data[CONTEXT_ENTITY_ID] = context_entity_id
if include_entity_name: if self.include_entity_name:
data[CONTEXT_ENTITY_ID_NAME] = self.entity_name_cache.get( data[CONTEXT_ENTITY_ID_NAME] = self.entity_name_cache.get(
context_entity_id, context_row context_entity_id, context_row
) )
@ -625,7 +629,7 @@ class ContextAugmenter:
if not (attr_entity_id := described.get(ATTR_ENTITY_ID)): if not (attr_entity_id := described.get(ATTR_ENTITY_ID)):
return return
data[CONTEXT_ENTITY_ID] = attr_entity_id data[CONTEXT_ENTITY_ID] = attr_entity_id
if include_entity_name: if self.include_entity_name:
data[CONTEXT_ENTITY_ID_NAME] = self.entity_name_cache.get( data[CONTEXT_ENTITY_ID_NAME] = self.entity_name_cache.get(
attr_entity_id, context_row attr_entity_id, context_row
) )