Convert logbook to use lambda_stmt (#71624)
parent
68c2b63ca1
commit
26177bd080
|
@ -11,11 +11,13 @@ from typing import Any, cast
|
|||
|
||||
from aiohttp import web
|
||||
import sqlalchemy
|
||||
from sqlalchemy import lambda_stmt, select
|
||||
from sqlalchemy.engine.row import Row
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm.query import Query
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sqlalchemy.sql.expression import literal
|
||||
from sqlalchemy.sql.lambdas import StatementLambdaElement
|
||||
from sqlalchemy.sql.selectable import Select
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import frontend
|
||||
|
@ -85,8 +87,6 @@ CONTINUOUS_ENTITY_ID_LIKE = [f"{domain}.%" for domain in CONTINUOUS_DOMAINS]
|
|||
|
||||
DOMAIN = "logbook"
|
||||
|
||||
GROUP_BY_MINUTES = 15
|
||||
|
||||
EMPTY_JSON_OBJECT = "{}"
|
||||
UNIT_OF_MEASUREMENT_JSON = '"unit_of_measurement":'
|
||||
UNIT_OF_MEASUREMENT_JSON_LIKE = f"%{UNIT_OF_MEASUREMENT_JSON}%"
|
||||
|
@ -435,70 +435,43 @@ def _get_events(
|
|||
|
||||
def yield_rows(query: Query) -> Generator[Row, None, None]:
|
||||
"""Yield Events that are not filtered away."""
|
||||
for row in query.yield_per(1000):
|
||||
if entity_ids or context_id:
|
||||
rows = query.all()
|
||||
else:
|
||||
rows = query.yield_per(1000)
|
||||
for row in rows:
|
||||
context_lookup.setdefault(row.context_id, row)
|
||||
if row.event_type != EVENT_CALL_SERVICE and (
|
||||
row.event_type == EVENT_STATE_CHANGED
|
||||
or _keep_row(hass, row, entities_filter)
|
||||
event_type = row.event_type
|
||||
if event_type != EVENT_CALL_SERVICE and (
|
||||
event_type == EVENT_STATE_CHANGED
|
||||
or _keep_row(hass, event_type, row, entities_filter)
|
||||
):
|
||||
yield row
|
||||
|
||||
if entity_ids is not None:
|
||||
entities_filter = generate_filter([], entity_ids, [], [])
|
||||
|
||||
event_types = [
|
||||
*ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED,
|
||||
*hass.data.get(DOMAIN, {}),
|
||||
]
|
||||
entity_filter = None
|
||||
if entity_ids is None and filters:
|
||||
entity_filter = filters.entity_filter() # type: ignore[no-untyped-call]
|
||||
stmt = _generate_logbook_query(
|
||||
start_day,
|
||||
end_day,
|
||||
event_types,
|
||||
entity_ids,
|
||||
entity_filter,
|
||||
entity_matches_only,
|
||||
context_id,
|
||||
)
|
||||
with session_scope(hass=hass) as session:
|
||||
old_state = aliased(States, name="old_state")
|
||||
query: Query
|
||||
query = _generate_events_query_without_states(session)
|
||||
query = _apply_event_time_filter(query, start_day, end_day)
|
||||
query = _apply_event_types_filter(
|
||||
hass, query, ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED
|
||||
)
|
||||
|
||||
if entity_ids is not None:
|
||||
if entity_matches_only:
|
||||
# When entity_matches_only is provided, contexts and events that do not
|
||||
# contain the entity_ids are not included in the logbook response.
|
||||
query = _apply_event_entity_id_matchers(query, entity_ids)
|
||||
query = query.outerjoin(EventData, (Events.data_id == EventData.data_id))
|
||||
query = query.union_all(
|
||||
_generate_states_query(
|
||||
session, start_day, end_day, old_state, entity_ids
|
||||
)
|
||||
)
|
||||
else:
|
||||
if context_id is not None:
|
||||
query = query.filter(Events.context_id == context_id)
|
||||
query = query.outerjoin(EventData, (Events.data_id == EventData.data_id))
|
||||
|
||||
states_query = _generate_states_query(
|
||||
session, start_day, end_day, old_state, entity_ids
|
||||
)
|
||||
unions: list[Query] = []
|
||||
if context_id is not None:
|
||||
# Once all the old `state_changed` events
|
||||
# are gone from the database remove the
|
||||
# _generate_legacy_events_context_id_query
|
||||
unions.append(
|
||||
_generate_legacy_events_context_id_query(
|
||||
session, context_id, start_day, end_day
|
||||
)
|
||||
)
|
||||
states_query = states_query.outerjoin(
|
||||
Events, (States.event_id == Events.event_id)
|
||||
)
|
||||
states_query = states_query.filter(States.context_id == context_id)
|
||||
elif filters:
|
||||
states_query = states_query.filter(filters.entity_filter()) # type: ignore[no-untyped-call]
|
||||
unions.append(states_query)
|
||||
query = query.union_all(*unions)
|
||||
|
||||
query = query.order_by(Events.time_fired)
|
||||
|
||||
return list(
|
||||
_humanify(
|
||||
hass,
|
||||
yield_rows(query),
|
||||
yield_rows(session.execute(stmt)),
|
||||
entity_name_cache,
|
||||
event_cache,
|
||||
context_augmenter,
|
||||
|
@ -506,8 +479,72 @@ def _get_events(
|
|||
)
|
||||
|
||||
|
||||
def _generate_events_query_without_data(session: Session) -> Query:
|
||||
return session.query(
|
||||
def _generate_logbook_query(
|
||||
start_day: dt,
|
||||
end_day: dt,
|
||||
event_types: list[str],
|
||||
entity_ids: list[str] | None = None,
|
||||
entity_filter: Any | None = None,
|
||||
entity_matches_only: bool = False,
|
||||
context_id: str | None = None,
|
||||
) -> StatementLambdaElement:
|
||||
"""Generate a logbook query lambda_stmt."""
|
||||
stmt = lambda_stmt(
|
||||
lambda: _generate_events_query_without_states()
|
||||
.where((Events.time_fired > start_day) & (Events.time_fired < end_day))
|
||||
.where(Events.event_type.in_(event_types))
|
||||
.outerjoin(EventData, (Events.data_id == EventData.data_id))
|
||||
)
|
||||
if entity_ids is not None:
|
||||
if entity_matches_only:
|
||||
# When entity_matches_only is provided, contexts and events that do not
|
||||
# contain the entity_ids are not included in the logbook response.
|
||||
stmt.add_criteria(
|
||||
lambda s: s.where(_apply_event_entity_id_matchers(entity_ids)),
|
||||
track_on=entity_ids,
|
||||
)
|
||||
stmt += lambda s: s.union_all(
|
||||
_generate_states_query()
|
||||
.filter((States.last_updated > start_day) & (States.last_updated < end_day))
|
||||
.where(States.entity_id.in_(entity_ids))
|
||||
)
|
||||
else:
|
||||
if context_id is not None:
|
||||
# Once all the old `state_changed` events
|
||||
# are gone from the database remove the
|
||||
# union_all(_generate_legacy_events_context_id_query()....)
|
||||
stmt += lambda s: s.where(Events.context_id == context_id).union_all(
|
||||
_generate_legacy_events_context_id_query()
|
||||
.where((Events.time_fired > start_day) & (Events.time_fired < end_day))
|
||||
.where(Events.context_id == context_id),
|
||||
_generate_states_query()
|
||||
.where(
|
||||
(States.last_updated > start_day) & (States.last_updated < end_day)
|
||||
)
|
||||
.outerjoin(Events, (States.event_id == Events.event_id))
|
||||
.where(States.context_id == context_id),
|
||||
)
|
||||
elif entity_filter is not None:
|
||||
stmt += lambda s: s.union_all(
|
||||
_generate_states_query()
|
||||
.where(
|
||||
(States.last_updated > start_day) & (States.last_updated < end_day)
|
||||
)
|
||||
.where(entity_filter)
|
||||
)
|
||||
else:
|
||||
stmt += lambda s: s.union_all(
|
||||
_generate_states_query().where(
|
||||
(States.last_updated > start_day) & (States.last_updated < end_day)
|
||||
)
|
||||
)
|
||||
|
||||
stmt += lambda s: s.order_by(Events.time_fired)
|
||||
return stmt
|
||||
|
||||
|
||||
def _generate_events_query_without_data() -> Select:
|
||||
return select(
|
||||
literal(value=EVENT_STATE_CHANGED, type_=sqlalchemy.String).label("event_type"),
|
||||
literal(value=None, type_=sqlalchemy.Text).label("event_data"),
|
||||
States.last_changed.label("time_fired"),
|
||||
|
@ -519,65 +556,48 @@ def _generate_events_query_without_data(session: Session) -> Query:
|
|||
)
|
||||
|
||||
|
||||
def _generate_legacy_events_context_id_query(
|
||||
session: Session,
|
||||
context_id: str,
|
||||
start_day: dt,
|
||||
end_day: dt,
|
||||
) -> Query:
|
||||
def _generate_legacy_events_context_id_query() -> Select:
|
||||
"""Generate a legacy events context id query that also joins states."""
|
||||
# This can be removed once we no longer have event_ids in the states table
|
||||
legacy_context_id_query = session.query(
|
||||
*EVENT_COLUMNS,
|
||||
literal(value=None, type_=sqlalchemy.String).label("shared_data"),
|
||||
States.state,
|
||||
States.entity_id,
|
||||
States.attributes,
|
||||
StateAttributes.shared_attrs,
|
||||
)
|
||||
legacy_context_id_query = _apply_event_time_filter(
|
||||
legacy_context_id_query, start_day, end_day
|
||||
)
|
||||
return (
|
||||
legacy_context_id_query.filter(Events.context_id == context_id)
|
||||
select(
|
||||
*EVENT_COLUMNS,
|
||||
literal(value=None, type_=sqlalchemy.String).label("shared_data"),
|
||||
States.state,
|
||||
States.entity_id,
|
||||
States.attributes,
|
||||
StateAttributes.shared_attrs,
|
||||
)
|
||||
.outerjoin(States, (Events.event_id == States.event_id))
|
||||
.filter(States.last_updated == States.last_changed)
|
||||
.filter(_not_continuous_entity_matcher())
|
||||
.where(States.last_updated == States.last_changed)
|
||||
.where(_not_continuous_entity_matcher())
|
||||
.outerjoin(
|
||||
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _generate_events_query_without_states(session: Session) -> Query:
|
||||
return session.query(
|
||||
def _generate_events_query_without_states() -> Select:
|
||||
return select(
|
||||
*EVENT_COLUMNS, EventData.shared_data.label("shared_data"), *EMPTY_STATE_COLUMNS
|
||||
)
|
||||
|
||||
|
||||
def _generate_states_query(
|
||||
session: Session,
|
||||
start_day: dt,
|
||||
end_day: dt,
|
||||
old_state: States,
|
||||
entity_ids: Iterable[str] | None,
|
||||
) -> Query:
|
||||
query = (
|
||||
_generate_events_query_without_data(session)
|
||||
def _generate_states_query() -> Select:
|
||||
old_state = aliased(States, name="old_state")
|
||||
return (
|
||||
_generate_events_query_without_data()
|
||||
.outerjoin(old_state, (States.old_state_id == old_state.state_id))
|
||||
.filter(_missing_state_matcher(old_state))
|
||||
.filter(_not_continuous_entity_matcher())
|
||||
.filter((States.last_updated > start_day) & (States.last_updated < end_day))
|
||||
.filter(States.last_updated == States.last_changed)
|
||||
)
|
||||
if entity_ids:
|
||||
query = query.filter(States.entity_id.in_(entity_ids))
|
||||
return query.outerjoin(
|
||||
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
|
||||
.where(_missing_state_matcher(old_state))
|
||||
.where(_not_continuous_entity_matcher())
|
||||
.where(States.last_updated == States.last_changed)
|
||||
.outerjoin(
|
||||
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _missing_state_matcher(old_state: States) -> Any:
|
||||
def _missing_state_matcher(old_state: States) -> sqlalchemy.and_:
|
||||
# The below removes state change events that do not have
|
||||
# and old_state or the old_state is missing (newly added entities)
|
||||
# or the new_state is missing (removed entities)
|
||||
|
@ -588,7 +608,7 @@ def _missing_state_matcher(old_state: States) -> Any:
|
|||
)
|
||||
|
||||
|
||||
def _not_continuous_entity_matcher() -> Any:
|
||||
def _not_continuous_entity_matcher() -> sqlalchemy.or_:
|
||||
"""Match non continuous entities."""
|
||||
return sqlalchemy.or_(
|
||||
_not_continuous_domain_matcher(),
|
||||
|
@ -598,7 +618,7 @@ def _not_continuous_entity_matcher() -> Any:
|
|||
)
|
||||
|
||||
|
||||
def _not_continuous_domain_matcher() -> Any:
|
||||
def _not_continuous_domain_matcher() -> sqlalchemy.and_:
|
||||
"""Match not continuous domains."""
|
||||
return sqlalchemy.and_(
|
||||
*[
|
||||
|
@ -608,7 +628,7 @@ def _not_continuous_domain_matcher() -> Any:
|
|||
).self_group()
|
||||
|
||||
|
||||
def _continuous_domain_matcher() -> Any:
|
||||
def _continuous_domain_matcher() -> sqlalchemy.or_:
|
||||
"""Match continuous domains."""
|
||||
return sqlalchemy.or_(
|
||||
*[
|
||||
|
@ -625,37 +645,22 @@ def _not_uom_attributes_matcher() -> Any:
|
|||
) | ~States.attributes.like(UNIT_OF_MEASUREMENT_JSON_LIKE)
|
||||
|
||||
|
||||
def _apply_event_time_filter(events_query: Query, start_day: dt, end_day: dt) -> Query:
|
||||
return events_query.filter(
|
||||
(Events.time_fired > start_day) & (Events.time_fired < end_day)
|
||||
)
|
||||
|
||||
|
||||
def _apply_event_types_filter(
|
||||
hass: HomeAssistant, query: Query, event_types: list[str]
|
||||
) -> Query:
|
||||
return query.filter(
|
||||
Events.event_type.in_(event_types + list(hass.data.get(DOMAIN, {})))
|
||||
)
|
||||
|
||||
|
||||
def _apply_event_entity_id_matchers(
|
||||
events_query: Query, entity_ids: Iterable[str]
|
||||
) -> Query:
|
||||
def _apply_event_entity_id_matchers(entity_ids: Iterable[str]) -> sqlalchemy.or_:
|
||||
"""Create matchers for the entity_id in the event_data."""
|
||||
ors = []
|
||||
for entity_id in entity_ids:
|
||||
like = ENTITY_ID_JSON_TEMPLATE.format(entity_id)
|
||||
ors.append(Events.event_data.like(like))
|
||||
ors.append(EventData.shared_data.like(like))
|
||||
return events_query.filter(sqlalchemy.or_(*ors))
|
||||
return sqlalchemy.or_(*ors)
|
||||
|
||||
|
||||
def _keep_row(
|
||||
hass: HomeAssistant,
|
||||
event_type: str,
|
||||
row: Row,
|
||||
entities_filter: EntityFilter | Callable[[str], bool] | None = None,
|
||||
) -> bool:
|
||||
event_type = row.event_type
|
||||
if event_type in HOMEASSISTANT_EVENTS:
|
||||
return entities_filter is None or entities_filter(HA_DOMAIN_ENTITY_ID)
|
||||
|
||||
|
|
|
@ -1390,6 +1390,59 @@ async def test_logbook_entity_matches_only(hass, hass_client, recorder_mock):
|
|||
assert json_dict[1]["context_user_id"] == "9400facee45711eaa9308bfd3d19e474"
|
||||
|
||||
|
||||
async def test_logbook_entity_matches_only_multiple_calls(
|
||||
hass, hass_client, recorder_mock
|
||||
):
|
||||
"""Test the logbook view with a single entity and entity_matches_only called multiple times."""
|
||||
await async_setup_component(hass, "logbook", {})
|
||||
await async_setup_component(hass, "automation", {})
|
||||
|
||||
await async_recorder_block_till_done(hass)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_start()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
for automation_id in range(5):
|
||||
hass.bus.async_fire(
|
||||
EVENT_AUTOMATION_TRIGGERED,
|
||||
{
|
||||
ATTR_NAME: f"Mock automation {automation_id}",
|
||||
ATTR_ENTITY_ID: f"automation.mock_{automation_id}_automation",
|
||||
},
|
||||
)
|
||||
await async_wait_recording_done(hass)
|
||||
client = await hass_client()
|
||||
|
||||
# Today time 00:00:00
|
||||
start = dt_util.utcnow().date()
|
||||
start_date = datetime(start.year, start.month, start.day)
|
||||
end_time = start + timedelta(hours=24)
|
||||
|
||||
for automation_id in range(5):
|
||||
# Test today entries with filter by end_time
|
||||
response = await client.get(
|
||||
f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=automation.mock_{automation_id}_automation&entity_matches_only"
|
||||
)
|
||||
assert response.status == HTTPStatus.OK
|
||||
json_dict = await response.json()
|
||||
|
||||
assert len(json_dict) == 1
|
||||
assert (
|
||||
json_dict[0]["entity_id"] == f"automation.mock_{automation_id}_automation"
|
||||
)
|
||||
|
||||
response = await client.get(
|
||||
f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=automation.mock_0_automation,automation.mock_1_automation,automation.mock_2_automation&entity_matches_only"
|
||||
)
|
||||
assert response.status == HTTPStatus.OK
|
||||
json_dict = await response.json()
|
||||
|
||||
assert len(json_dict) == 3
|
||||
assert json_dict[0]["entity_id"] == "automation.mock_0_automation"
|
||||
assert json_dict[1]["entity_id"] == "automation.mock_1_automation"
|
||||
assert json_dict[2]["entity_id"] == "automation.mock_2_automation"
|
||||
|
||||
|
||||
async def test_custom_log_entry_discoverable_via_entity_matches_only(
|
||||
hass, hass_client, recorder_mock
|
||||
):
|
||||
|
|
Loading…
Reference in New Issue