Convert logbook to use lambda_stmt (#71624)

J. Nick Koston 2022-05-10 08:23:13 -05:00 committed by GitHub
parent 68c2b63ca1
commit 26177bd080
2 changed files with 179 additions and 121 deletions
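Background for the diff below: `lambda_stmt` is SQLAlchemy 1.4's cached-statement construct. The SELECT is built inside a lambda so SQLAlchemy can compute the SQL string once, keyed on the lambda's code location, and reuse it on later calls with fresh bound parameter values. A minimal, self-contained sketch of the pattern follows; the `events` table and `build_events_stmt` helper here are hypothetical stand-ins for illustration, not code from this commit:

from sqlalchemy import (
    Column,
    Integer,
    MetaData,
    String,
    Table,
    create_engine,
    lambda_stmt,
    select,
)

metadata = MetaData()
# Hypothetical stand-in table; the commit itself targets the recorder's models.
events = Table(
    "events",
    metadata,
    Column("event_id", Integer, primary_key=True),
    Column("event_type", String(64)),
)


def build_events_stmt(event_type: str):
    # The select() construction runs once per lambda code location; on
    # later calls event_type is swapped in as a new bound parameter.
    stmt = lambda_stmt(lambda: select(events))
    stmt += lambda s: s.where(events.c.event_type == event_type)
    stmt += lambda s: s.order_by(events.c.event_id)
    return stmt


engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.connect() as conn:
    rows = conn.execute(build_events_stmt("state_changed")).all()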

@@ -11,11 +11,13 @@ from typing import Any, cast
 from aiohttp import web
 import sqlalchemy
+from sqlalchemy import lambda_stmt, select
 from sqlalchemy.engine.row import Row
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm.query import Query
-from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal
+from sqlalchemy.sql.lambdas import StatementLambdaElement
+from sqlalchemy.sql.selectable import Select
 import voluptuous as vol
 
 from homeassistant.components import frontend
@@ -85,8 +87,6 @@ CONTINUOUS_ENTITY_ID_LIKE = [f"{domain}.%" for domain in CONTINUOUS_DOMAINS]
 
 DOMAIN = "logbook"
 GROUP_BY_MINUTES = 15
 
 EMPTY_JSON_OBJECT = "{}"
 UNIT_OF_MEASUREMENT_JSON = '"unit_of_measurement":'
 UNIT_OF_MEASUREMENT_JSON_LIKE = f"%{UNIT_OF_MEASUREMENT_JSON}%"
@@ -435,70 +435,43 @@ def _get_events(
     def yield_rows(query: Query) -> Generator[Row, None, None]:
         """Yield Events that are not filtered away."""
-        for row in query.yield_per(1000):
+        if entity_ids or context_id:
+            rows = query.all()
+        else:
+            rows = query.yield_per(1000)
+        for row in rows:
             context_lookup.setdefault(row.context_id, row)
-            if row.event_type != EVENT_CALL_SERVICE and (
-                row.event_type == EVENT_STATE_CHANGED
-                or _keep_row(hass, row, entities_filter)
+            event_type = row.event_type
+            if event_type != EVENT_CALL_SERVICE and (
+                event_type == EVENT_STATE_CHANGED
+                or _keep_row(hass, event_type, row, entities_filter)
             ):
                 yield row
 
     if entity_ids is not None:
         entities_filter = generate_filter([], entity_ids, [], [])
 
+    event_types = [
+        *ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED,
+        *hass.data.get(DOMAIN, {}),
+    ]
+    entity_filter = None
+    if entity_ids is None and filters:
+        entity_filter = filters.entity_filter()  # type: ignore[no-untyped-call]
+    stmt = _generate_logbook_query(
+        start_day,
+        end_day,
+        event_types,
+        entity_ids,
+        entity_filter,
+        entity_matches_only,
+        context_id,
+    )
     with session_scope(hass=hass) as session:
-        old_state = aliased(States, name="old_state")
-        query: Query
-        query = _generate_events_query_without_states(session)
-        query = _apply_event_time_filter(query, start_day, end_day)
-        query = _apply_event_types_filter(
-            hass, query, ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED
-        )
-        if entity_ids is not None:
-            if entity_matches_only:
-                # When entity_matches_only is provided, contexts and events that do not
-                # contain the entity_ids are not included in the logbook response.
-                query = _apply_event_entity_id_matchers(query, entity_ids)
-            query = query.outerjoin(EventData, (Events.data_id == EventData.data_id))
-            query = query.union_all(
-                _generate_states_query(
-                    session, start_day, end_day, old_state, entity_ids
-                )
-            )
-        else:
-            if context_id is not None:
-                query = query.filter(Events.context_id == context_id)
-            query = query.outerjoin(EventData, (Events.data_id == EventData.data_id))
-            states_query = _generate_states_query(
-                session, start_day, end_day, old_state, entity_ids
-            )
-            unions: list[Query] = []
-            if context_id is not None:
-                # Once all the old `state_changed` events
-                # are gone from the database remove the
-                # _generate_legacy_events_context_id_query
-                unions.append(
-                    _generate_legacy_events_context_id_query(
-                        session, context_id, start_day, end_day
-                    )
-                )
-                states_query = states_query.outerjoin(
-                    Events, (States.event_id == Events.event_id)
-                )
-                states_query = states_query.filter(States.context_id == context_id)
-            elif filters:
-                states_query = states_query.filter(filters.entity_filter())  # type: ignore[no-untyped-call]
-            unions.append(states_query)
-            query = query.union_all(*unions)
-
-        query = query.order_by(Events.time_fired)
-
         return list(
             _humanify(
                 hass,
-                yield_rows(query),
+                yield_rows(session.execute(stmt)),
                 entity_name_cache,
                 event_cache,
                 context_augmenter,
@@ -506,8 +479,72 @@ def _get_events(
         )
 
 
-def _generate_events_query_without_data(session: Session) -> Query:
-    return session.query(
+def _generate_logbook_query(
+    start_day: dt,
+    end_day: dt,
+    event_types: list[str],
+    entity_ids: list[str] | None = None,
+    entity_filter: Any | None = None,
+    entity_matches_only: bool = False,
+    context_id: str | None = None,
+) -> StatementLambdaElement:
+    """Generate a logbook query lambda_stmt."""
+    stmt = lambda_stmt(
+        lambda: _generate_events_query_without_states()
+        .where((Events.time_fired > start_day) & (Events.time_fired < end_day))
+        .where(Events.event_type.in_(event_types))
+        .outerjoin(EventData, (Events.data_id == EventData.data_id))
+    )
+    if entity_ids is not None:
+        if entity_matches_only:
+            # When entity_matches_only is provided, contexts and events that do not
+            # contain the entity_ids are not included in the logbook response.
+            stmt = stmt.add_criteria(
+                lambda s: s.where(_apply_event_entity_id_matchers(entity_ids)),
+                track_on=entity_ids,
+            )
+        stmt += lambda s: s.union_all(
+            _generate_states_query()
+            .filter((States.last_updated > start_day) & (States.last_updated < end_day))
+            .where(States.entity_id.in_(entity_ids))
+        )
+    else:
+        if context_id is not None:
+            # Once all the old `state_changed` events
+            # are gone from the database remove the
+            # union_all(_generate_legacy_events_context_id_query()....)
+            stmt += lambda s: s.where(Events.context_id == context_id).union_all(
+                _generate_legacy_events_context_id_query()
+                .where((Events.time_fired > start_day) & (Events.time_fired < end_day))
+                .where(Events.context_id == context_id),
+                _generate_states_query()
+                .where(
+                    (States.last_updated > start_day) & (States.last_updated < end_day)
+                )
+                .outerjoin(Events, (States.event_id == Events.event_id))
+                .where(States.context_id == context_id),
+            )
+        elif entity_filter is not None:
+            stmt += lambda s: s.union_all(
+                _generate_states_query()
+                .where(
+                    (States.last_updated > start_day) & (States.last_updated < end_day)
+                )
+                .where(entity_filter)
+            )
+        else:
+            stmt += lambda s: s.union_all(
+                _generate_states_query().where(
+                    (States.last_updated > start_day) & (States.last_updated < end_day)
+                )
+            )
+    stmt += lambda s: s.order_by(Events.time_fired)
+    return stmt
+
+
+def _generate_events_query_without_data() -> Select:
+    return select(
         literal(value=EVENT_STATE_CHANGED, type_=sqlalchemy.String).label("event_type"),
         literal(value=None, type_=sqlalchemy.Text).label("event_data"),
         States.last_changed.label("time_fired"),
@@ -519,65 +556,48 @@ def _generate_events_query_without_data(session: Session) -> Query:
     )
 
 
-def _generate_legacy_events_context_id_query(
-    session: Session,
-    context_id: str,
-    start_day: dt,
-    end_day: dt,
-) -> Query:
+def _generate_legacy_events_context_id_query() -> Select:
     """Generate a legacy events context id query that also joins states."""
     # This can be removed once we no longer have event_ids in the states table
-    legacy_context_id_query = session.query(
-        *EVENT_COLUMNS,
-        literal(value=None, type_=sqlalchemy.String).label("shared_data"),
-        States.state,
-        States.entity_id,
-        States.attributes,
-        StateAttributes.shared_attrs,
-    )
-    legacy_context_id_query = _apply_event_time_filter(
-        legacy_context_id_query, start_day, end_day
-    )
     return (
-        legacy_context_id_query.filter(Events.context_id == context_id)
+        select(
+            *EVENT_COLUMNS,
+            literal(value=None, type_=sqlalchemy.String).label("shared_data"),
+            States.state,
+            States.entity_id,
+            States.attributes,
+            StateAttributes.shared_attrs,
+        )
         .outerjoin(States, (Events.event_id == States.event_id))
-        .filter(States.last_updated == States.last_changed)
-        .filter(_not_continuous_entity_matcher())
+        .where(States.last_updated == States.last_changed)
+        .where(_not_continuous_entity_matcher())
         .outerjoin(
             StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
         )
     )
 
 
-def _generate_events_query_without_states(session: Session) -> Query:
-    return session.query(
+def _generate_events_query_without_states() -> Select:
+    return select(
         *EVENT_COLUMNS, EventData.shared_data.label("shared_data"), *EMPTY_STATE_COLUMNS
     )
 
 
-def _generate_states_query(
-    session: Session,
-    start_day: dt,
-    end_day: dt,
-    old_state: States,
-    entity_ids: Iterable[str] | None,
-) -> Query:
-    query = (
-        _generate_events_query_without_data(session)
+def _generate_states_query() -> Select:
+    old_state = aliased(States, name="old_state")
+    return (
+        _generate_events_query_without_data()
         .outerjoin(old_state, (States.old_state_id == old_state.state_id))
-        .filter(_missing_state_matcher(old_state))
-        .filter(_not_continuous_entity_matcher())
-        .filter((States.last_updated > start_day) & (States.last_updated < end_day))
-        .filter(States.last_updated == States.last_changed)
-    )
-    if entity_ids:
-        query = query.filter(States.entity_id.in_(entity_ids))
-    return query.outerjoin(
-        StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
+        .where(_missing_state_matcher(old_state))
+        .where(_not_continuous_entity_matcher())
+        .where(States.last_updated == States.last_changed)
+        .outerjoin(
+            StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
+        )
     )
 
 
-def _missing_state_matcher(old_state: States) -> Any:
+def _missing_state_matcher(old_state: States) -> sqlalchemy.and_:
     # The below removes state change events that do not have
     # and old_state or the old_state is missing (newly added entities)
     # or the new_state is missing (removed entities)
@@ -588,7 +608,7 @@ def _missing_state_matcher(old_state: States) -> Any:
     )
 
 
-def _not_continuous_entity_matcher() -> Any:
+def _not_continuous_entity_matcher() -> sqlalchemy.or_:
     """Match non continuous entities."""
     return sqlalchemy.or_(
         _not_continuous_domain_matcher(),
@@ -598,7 +618,7 @@ def _not_continuous_entity_matcher() -> Any:
     )
 
 
-def _not_continuous_domain_matcher() -> Any:
+def _not_continuous_domain_matcher() -> sqlalchemy.and_:
    """Match not continuous domains."""
    return sqlalchemy.and_(
        *[
@@ -608,7 +628,7 @@ def _not_continuous_domain_matcher() -> Any:
     ).self_group()
 
 
-def _continuous_domain_matcher() -> Any:
+def _continuous_domain_matcher() -> sqlalchemy.or_:
     """Match continuous domains."""
     return sqlalchemy.or_(
         *[
@@ -625,37 +645,22 @@ def _not_uom_attributes_matcher() -> Any:
     ) | ~States.attributes.like(UNIT_OF_MEASUREMENT_JSON_LIKE)
 
 
-def _apply_event_time_filter(events_query: Query, start_day: dt, end_day: dt) -> Query:
-    return events_query.filter(
-        (Events.time_fired > start_day) & (Events.time_fired < end_day)
-    )
-
-
-def _apply_event_types_filter(
-    hass: HomeAssistant, query: Query, event_types: list[str]
-) -> Query:
-    return query.filter(
-        Events.event_type.in_(event_types + list(hass.data.get(DOMAIN, {})))
-    )
-
-
-def _apply_event_entity_id_matchers(
-    events_query: Query, entity_ids: Iterable[str]
-) -> Query:
+def _apply_event_entity_id_matchers(entity_ids: Iterable[str]) -> sqlalchemy.or_:
     """Create matchers for the entity_id in the event_data."""
     ors = []
     for entity_id in entity_ids:
         like = ENTITY_ID_JSON_TEMPLATE.format(entity_id)
         ors.append(Events.event_data.like(like))
         ors.append(EventData.shared_data.like(like))
-    return events_query.filter(sqlalchemy.or_(*ors))
+    return sqlalchemy.or_(*ors)
 
 
 def _keep_row(
     hass: HomeAssistant,
+    event_type: str,
     row: Row,
     entities_filter: EntityFilter | Callable[[str], bool] | None = None,
 ) -> bool:
-    event_type = row.event_type
     if event_type in HOMEASSISTANT_EVENTS:
         return entities_filter is None or entities_filter(HA_DOMAIN_ENTITY_ID)
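Two details of the lambda composition in `_generate_logbook_query` above are easy to miss. First, `stmt += lambda s: ...` chains a new criteria lambda onto the statement, while `add_criteria()` does the same but returns a new `StatementLambdaElement` rather than mutating in place, which is why its result is assigned back to `stmt`. Second, a closure value like the `entity_ids` list cannot be captured as a single bound parameter by the lambda cache, so it is declared as part of the cache key via `track_on`; otherwise requests with different entity lists could not be told apart. A minimal sketch of the pattern, again with a hypothetical `states` stand-in table rather than the recorder's real model:

from sqlalchemy import Column, Integer, MetaData, String, Table, lambda_stmt, select
from sqlalchemy.sql.lambdas import StatementLambdaElement

metadata = MetaData()
# Hypothetical stand-in for the recorder's States model.
states = Table(
    "states",
    metadata,
    Column("state_id", Integer, primary_key=True),
    Column("entity_id", String(255)),
)


def build_states_stmt(entity_ids: list[str] | None) -> StatementLambdaElement:
    stmt = lambda_stmt(lambda: select(states))
    if entity_ids is not None:
        # A list used with in_() is not trackable as one bound parameter,
        # so entity_ids is named in track_on to key the statement cache.
        stmt = stmt.add_criteria(
            lambda s: s.where(states.c.entity_id.in_(entity_ids)),
            track_on=entity_ids,
        )
    stmt += lambda s: s.order_by(states.c.state_id)
    return stmt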

@@ -1390,6 +1390,59 @@ async def test_logbook_entity_matches_only(hass, hass_client, recorder_mock):
     assert json_dict[1]["context_user_id"] == "9400facee45711eaa9308bfd3d19e474"
 
 
+async def test_logbook_entity_matches_only_multiple_calls(
+    hass, hass_client, recorder_mock
+):
+    """Test the logbook view with a single entity and entity_matches_only called multiple times."""
+    await async_setup_component(hass, "logbook", {})
+    await async_setup_component(hass, "automation", {})
+
+    await async_recorder_block_till_done(hass)
+    await hass.async_block_till_done()
+    await hass.async_start()
+    await hass.async_block_till_done()
+
+    for automation_id in range(5):
+        hass.bus.async_fire(
+            EVENT_AUTOMATION_TRIGGERED,
+            {
+                ATTR_NAME: f"Mock automation {automation_id}",
+                ATTR_ENTITY_ID: f"automation.mock_{automation_id}_automation",
+            },
+        )
+
+    await async_wait_recording_done(hass)
+    client = await hass_client()
+
+    # Today time 00:00:00
+    start = dt_util.utcnow().date()
+    start_date = datetime(start.year, start.month, start.day)
+    end_time = start + timedelta(hours=24)
+
+    for automation_id in range(5):
+        # Test today entries with filter by end_time
+        response = await client.get(
+            f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=automation.mock_{automation_id}_automation&entity_matches_only"
+        )
+        assert response.status == HTTPStatus.OK
+        json_dict = await response.json()
+
+        assert len(json_dict) == 1
+        assert (
+            json_dict[0]["entity_id"] == f"automation.mock_{automation_id}_automation"
+        )
+
+    response = await client.get(
+        f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=automation.mock_0_automation,automation.mock_1_automation,automation.mock_2_automation&entity_matches_only"
+    )
+    assert response.status == HTTPStatus.OK
+    json_dict = await response.json()
+
+    assert len(json_dict) == 3
+    assert json_dict[0]["entity_id"] == "automation.mock_0_automation"
+    assert json_dict[1]["entity_id"] == "automation.mock_1_automation"
+    assert json_dict[2]["entity_id"] == "automation.mock_2_automation"
+
+
 async def test_custom_log_entry_discoverable_via_entity_matches_only(
     hass, hass_client, recorder_mock
 ):
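The multiple-calls test added above pins down the behavior that `track_on` exists to protect: consecutive requests with different entity lists must each get their own statement rather than silently reusing the first request's cached criteria. A rough standalone analogue of that check, reusing the hypothetical `metadata`, `states`, and `build_states_stmt` from the earlier sketch (illustration only, not part of the commit's test suite):

from sqlalchemy import create_engine, insert

engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(
        insert(states),
        [{"entity_id": f"automation.mock_{i}_automation"} for i in range(5)],
    )

with engine.connect() as conn:
    one = conn.execute(build_states_stmt(["automation.mock_0_automation"])).all()
    three = conn.execute(
        build_states_stmt(
            [
                "automation.mock_0_automation",
                "automation.mock_1_automation",
                "automation.mock_2_automation",
            ]
        )
    ).all()

# Each call sees its own entity list rather than the first cached criteria.
assert len(one) == 1 and len(three) == 3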