Fixes for logbook filtering and add it to the live stream (#72501)

J. Nick Koston 2022-05-25 15:17:08 -10:00 committed by GitHub
parent 1ac71455cb
commit bfa7693d18
11 changed files with 340 additions and 114 deletions

View File

@@ -173,12 +173,6 @@ class EventProcessor:
             self.filters,
             self.context_id,
         )
-        if _LOGGER.isEnabledFor(logging.DEBUG):
-            _LOGGER.debug(
-                "Literal statement: %s",
-                stmt.compile(compile_kwargs={"literal_binds": True}),
-            )
         with session_scope(hass=self.hass) as session:
             return self.humanify(yield_rows(session.execute(stmt)))
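
The block removed above was a DEBUG-level dump of the query compiled with literal binds. For reference, a minimal standalone sketch of that logging pattern (the table and names here are illustrative, not the real recorder schema):

```python
import logging

from sqlalchemy import Column, Integer, String, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()
_LOGGER = logging.getLogger(__name__)


class DemoStates(Base):  # hypothetical stand-in for the recorder States table
    __tablename__ = "demo_states"
    state_id = Column(Integer, primary_key=True)
    entity_id = Column(String(255))


stmt = select(DemoStates.state_id).where(DemoStates.entity_id == "light.kitchen")
if _LOGGER.isEnabledFor(logging.DEBUG):
    # literal_binds renders bound parameters inline; this fails for types
    # without a literal processor, which is why models.py grows a JSONLiteral
    # helper later in this commit
    _LOGGER.debug(
        "Literal statement: %s",
        stmt.compile(compile_kwargs={"literal_binds": True}),
    )
```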
@@ -214,20 +208,16 @@ def _humanify(
     include_entity_name = logbook_run.include_entity_name
     format_time = logbook_run.format_time
-    def _keep_row(row: Row | EventAsRow, event_type: str) -> bool:
+    def _keep_row(row: EventAsRow) -> bool:
         """Check if the entity_filter rejects a row."""
         assert entities_filter is not None
-        if entity_id := _row_event_data_extract(row, ENTITY_ID_JSON_EXTRACT):
+        if entity_id := row.entity_id:
             return entities_filter(entity_id)
-        if event_type in external_events:
-            # If the entity_id isn't described, use the domain that describes
-            # the event for filtering.
-            domain: str | None = external_events[event_type][0]
-        else:
-            domain = _row_event_data_extract(row, DOMAIN_JSON_EXTRACT)
-        return domain is not None and entities_filter(f"{domain}._")
+        if entity_id := row.data.get(ATTR_ENTITY_ID):
+            return entities_filter(entity_id)
+        if domain := row.data.get(ATTR_DOMAIN):
+            return entities_filter(f"{domain}._")
         return True
     # Process rows
     for row in rows:
@@ -236,12 +226,12 @@ def _humanify(
             continue
         event_type = row.event_type
         if event_type == EVENT_CALL_SERVICE or (
-            event_type is not PSUEDO_EVENT_STATE_CHANGED
-            and entities_filter is not None
-            and not _keep_row(row, event_type)
+            entities_filter
+            # We literally mean is EventAsRow not a subclass of EventAsRow
+            and type(row) is EventAsRow  # pylint: disable=unidiomatic-typecheck
+            and not _keep_row(row)
         ):
             continue
         if event_type is PSUEDO_EVENT_STATE_CHANGED:
             entity_id = row.entity_id
             assert entity_id is not None
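
A pure-Python sketch of the new filtering order above (constant names as in the PR; `EventAsRow` reduced to the two fields used here): the row's own entity_id wins, then the event data's entity_id, then its domain, and rows that match nothing are kept.

```python
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any

ATTR_ENTITY_ID = "entity_id"
ATTR_DOMAIN = "domain"


@dataclass
class EventAsRow:  # reduced stand-in for the real row type
    entity_id: str | None = None
    data: dict[str, Any] = field(default_factory=dict)


def keep_row(row: EventAsRow, entities_filter: Callable[[str], bool]) -> bool:
    """Mirror _keep_row: entity_id first, then event data, else keep."""
    if entity_id := row.entity_id:
        return entities_filter(entity_id)
    if entity_id := row.data.get(ATTR_ENTITY_ID):
        return entities_filter(entity_id)
    if domain := row.data.get(ATTR_DOMAIN):
        # Domain-only events are filtered as the synthetic entity "<domain>._"
        return entities_filter(f"{domain}._")
    return True


def allow_lights(entity_id: str) -> bool:
    return entity_id.startswith("light.")


assert keep_row(EventAsRow(entity_id="light.kitchen"), allow_lights)
assert not keep_row(EventAsRow(data={ATTR_DOMAIN: "switch"}), allow_lights)
assert keep_row(EventAsRow(), allow_lights)  # nothing to match: keep the row
```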

View File

@@ -27,8 +27,16 @@ def statement_for_request(
     # No entities: logbook sends everything for the timeframe
     # limited by the context_id and the yaml configured filter
     if not entity_ids and not device_ids:
-        entity_filter = filters.entity_filter() if filters else None
-        return all_stmt(start_day, end_day, event_types, entity_filter, context_id)
+        states_entity_filter = filters.states_entity_filter() if filters else None
+        events_entity_filter = filters.events_entity_filter() if filters else None
+        return all_stmt(
+            start_day,
+            end_day,
+            event_types,
+            states_entity_filter,
+            events_entity_filter,
+            context_id,
+        )
     # sqlalchemy caches object quoting, the
     # json quotable ones must be a different
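
Why the single filter was split in two: the states query compares the plain `States.entity_id` column, while the events query compares entity ids extracted from JSON event data, where even the compared values must be JSON-encoded. A minimal sketch with illustrative tables (not the real recorder schema):

```python
import json

from sqlalchemy import JSON, Column, Integer, String
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class DemoStates(Base):  # stand-in for the recorder States table
    __tablename__ = "demo_states"
    state_id = Column(Integer, primary_key=True)
    entity_id = Column(String(255))


class DemoEventData(Base):  # stand-in for the recorder EventData table
    __tablename__ = "demo_event_data"
    data_id = Column(Integer, primary_key=True)
    shared_data = Column(JSON)


# States side: plain string comparison on a varchar column
states_clause = DemoStates.entity_id.like("switch.%")
# Events side: JSON path extraction, compared against a JSON-encoded value
events_clause = DemoEventData.shared_data["entity_id"] == json.dumps("switch.any")

print(states_clause.compile(dialect=sqlite.dialect()))
print(events_clause.compile(dialect=sqlite.dialect()))
```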

View File

@@ -22,7 +22,8 @@ def all_stmt(
     start_day: dt,
     end_day: dt,
     event_types: tuple[str, ...],
-    entity_filter: ClauseList | None = None,
+    states_entity_filter: ClauseList | None = None,
+    events_entity_filter: ClauseList | None = None,
     context_id: str | None = None,
 ) -> StatementLambdaElement:
     """Generate a logbook query for all entities."""
@@ -37,12 +38,17 @@
             _states_query_for_context_id(start_day, end_day, context_id),
             legacy_select_events_context_id(start_day, end_day, context_id),
         )
-    elif entity_filter is not None:
-        stmt += lambda s: s.union_all(
-            _states_query_for_all(start_day, end_day).where(entity_filter)
-        )
     else:
-        stmt += lambda s: s.union_all(_states_query_for_all(start_day, end_day))
+        if events_entity_filter is not None:
+            stmt += lambda s: s.where(events_entity_filter)
+        if states_entity_filter is not None:
+            stmt += lambda s: s.union_all(
+                _states_query_for_all(start_day, end_day).where(states_entity_filter)
+            )
+        else:
+            stmt += lambda s: s.union_all(_states_query_for_all(start_day, end_day))
     stmt += lambda s: s.order_by(Events.time_fired)
     return stmt
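
`all_stmt` builds on SQLAlchemy's cached lambda statements: a `StatementLambdaElement` is extended step by step with `stmt += lambda s: ...`. A minimal sketch of that pattern (illustrative table; the real query also unions in the states select):

```python
from sqlalchemy import Column, Integer, String, lambda_stmt, select
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql.lambdas import StatementLambdaElement

Base = declarative_base()


class DemoEvents(Base):  # stand-in for the recorder Events table
    __tablename__ = "demo_events"
    event_id = Column(Integer, primary_key=True)
    event_type = Column(String(64))
    time_fired = Column(Integer)


def demo_all_stmt(events_entity_filter=None) -> StatementLambdaElement:
    stmt = lambda_stmt(lambda: select(DemoEvents.event_id))
    if events_entity_filter is not None:
        # Like the new branch above: the events filter is applied with a
        # plain where() before any union with the states query
        stmt += lambda s: s.where(events_entity_filter)
    stmt += lambda s: s.order_by(DemoEvents.time_fired)
    return stmt


print(demo_all_stmt(DemoEvents.event_type == "call_service"))
```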

View File

@@ -1,22 +1,20 @@
 """Queries for logbook."""
 from __future__ import annotations
-from collections.abc import Callable
 from datetime import datetime as dt
-import json
-from typing import Any
 import sqlalchemy
-from sqlalchemy import JSON, select, type_coerce
-from sqlalchemy.orm import Query, aliased
+from sqlalchemy import select
+from sqlalchemy.orm import Query
 from sqlalchemy.sql.elements import ClauseList
 from sqlalchemy.sql.expression import literal
 from sqlalchemy.sql.selectable import Select
 from homeassistant.components.proximity import DOMAIN as PROXIMITY_DOMAIN
 from homeassistant.components.recorder.models import (
+    JSON_VARIENT_CAST,
+    JSONB_VARIENT_CAST,
+    OLD_FORMAT_ATTRS_JSON,
+    OLD_STATE,
+    SHARED_ATTRS_JSON,
     EventData,
     Events,
     StateAttributes,
@@ -30,36 +28,6 @@ CONTINUOUS_ENTITY_ID_LIKE = [f"{domain}.%" for domain in CONTINUOUS_DOMAINS]
 UNIT_OF_MEASUREMENT_JSON = '"unit_of_measurement":'
 UNIT_OF_MEASUREMENT_JSON_LIKE = f"%{UNIT_OF_MEASUREMENT_JSON}%"
-OLD_STATE = aliased(States, name="old_state")
-class JSONLiteral(JSON):  # type: ignore[misc]
-    """Teach SA how to literalize json."""
-    def literal_processor(self, dialect: str) -> Callable[[Any], str]:
-        """Processor to convert a value to JSON."""
-        def process(value: Any) -> str:
-            """Dump json."""
-            return json.dumps(value)
-        return process
-EVENT_DATA_JSON = type_coerce(
-    EventData.shared_data.cast(JSONB_VARIENT_CAST), JSONLiteral(none_as_null=True)
-)
-OLD_FORMAT_EVENT_DATA_JSON = type_coerce(
-    Events.event_data.cast(JSONB_VARIENT_CAST), JSONLiteral(none_as_null=True)
-)
-SHARED_ATTRS_JSON = type_coerce(
-    StateAttributes.shared_attrs.cast(JSON_VARIENT_CAST), JSON(none_as_null=True)
-)
-OLD_FORMAT_ATTRS_JSON = type_coerce(
-    States.attributes.cast(JSON_VARIENT_CAST), JSON(none_as_null=True)
-)
 PSUEDO_EVENT_STATE_CHANGED = None
 # Since we don't store event_types and None

View File

@@ -4,24 +4,21 @@ from __future__ import annotations
 from collections.abc import Iterable
 from datetime import datetime as dt
-from sqlalchemy import Column, lambda_stmt, select, union_all
+from sqlalchemy import lambda_stmt, select, union_all
 from sqlalchemy.orm import Query
 from sqlalchemy.sql.elements import ClauseList
 from sqlalchemy.sql.lambdas import StatementLambdaElement
 from sqlalchemy.sql.selectable import CTE, CompoundSelect
-from homeassistant.components.recorder.models import Events, States
+from homeassistant.components.recorder.models import DEVICE_ID_IN_EVENT, Events, States
 from .common import (
-    EVENT_DATA_JSON,
     select_events_context_id_subquery,
     select_events_context_only,
     select_events_without_states,
     select_states_context_only,
 )
-DEVICE_ID_IN_EVENT: Column = EVENT_DATA_JSON["device_id"]
 def _select_device_id_context_ids_sub_query(
     start_day: dt,
View File

@@ -5,20 +5,20 @@ from collections.abc import Iterable
 from datetime import datetime as dt
 import sqlalchemy
-from sqlalchemy import Column, lambda_stmt, select, union_all
+from sqlalchemy import lambda_stmt, select, union_all
 from sqlalchemy.orm import Query
 from sqlalchemy.sql.lambdas import StatementLambdaElement
 from sqlalchemy.sql.selectable import CTE, CompoundSelect
 from homeassistant.components.recorder.models import (
+    ENTITY_ID_IN_EVENT,
     ENTITY_ID_LAST_UPDATED_INDEX,
+    OLD_ENTITY_ID_IN_EVENT,
     Events,
     States,
 )
 from .common import (
-    EVENT_DATA_JSON,
-    OLD_FORMAT_EVENT_DATA_JSON,
     apply_states_filters,
     select_events_context_id_subquery,
     select_events_context_only,
@@ -27,9 +27,6 @@ from .common import (
     select_states_context_only,
 )
-ENTITY_ID_IN_EVENT: Column = EVENT_DATA_JSON["entity_id"]
-OLD_ENTITY_ID_IN_EVENT: Column = OLD_FORMAT_EVENT_DATA_JSON["entity_id"]
 def _select_entities_context_ids_sub_query(
     start_day: dt,

View File

@@ -1,14 +1,18 @@
 """Provide pre-made queries on top of the recorder component."""
 from __future__ import annotations
-from sqlalchemy import not_, or_
+from collections.abc import Callable, Iterable
+import json
+from typing import Any
+from sqlalchemy import Column, not_, or_
 from sqlalchemy.sql.elements import ClauseList
 from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
 from homeassistant.helpers.entityfilter import CONF_ENTITY_GLOBS
 from homeassistant.helpers.typing import ConfigType
-from .models import States
+from .models import ENTITY_ID_IN_EVENT, OLD_ENTITY_ID_IN_EVENT, States
 DOMAIN = "history"
 HISTORY_FILTERS = "history_filters"
@@ -59,50 +63,84 @@ class Filters:
             or self.included_entity_globs
         )
-    def entity_filter(self) -> ClauseList:
-        """Generate the entity filter query."""
+    def _generate_filter_for_columns(
+        self, columns: Iterable[Column], encoder: Callable[[Any], Any]
+    ) -> ClauseList:
         includes = []
         if self.included_domains:
-            includes.append(
-                or_(
-                    *[
-                        States.entity_id.like(f"{domain}.%")
-                        for domain in self.included_domains
-                    ]
-                ).self_group()
-            )
+            includes.append(_domain_matcher(self.included_domains, columns, encoder))
         if self.included_entities:
-            includes.append(States.entity_id.in_(self.included_entities))
-        for glob in self.included_entity_globs:
-            includes.append(_glob_to_like(glob))
+            includes.append(_entity_matcher(self.included_entities, columns, encoder))
+        if self.included_entity_globs:
+            includes.append(
+                _globs_to_like(self.included_entity_globs, columns, encoder)
+            )
         excludes = []
         if self.excluded_domains:
-            excludes.append(
-                or_(
-                    *[
-                        States.entity_id.like(f"{domain}.%")
-                        for domain in self.excluded_domains
-                    ]
-                ).self_group()
-            )
+            excludes.append(_domain_matcher(self.excluded_domains, columns, encoder))
         if self.excluded_entities:
-            excludes.append(States.entity_id.in_(self.excluded_entities))
-        for glob in self.excluded_entity_globs:
-            excludes.append(_glob_to_like(glob))
+            excludes.append(_entity_matcher(self.excluded_entities, columns, encoder))
+        if self.excluded_entity_globs:
+            excludes.append(
+                _globs_to_like(self.excluded_entity_globs, columns, encoder)
+            )
         if not includes and not excludes:
             return None
         if includes and not excludes:
-            return or_(*includes)
+            return or_(*includes).self_group()
         if not includes and excludes:
-            return not_(or_(*excludes))
+            return not_(or_(*excludes).self_group())
-        return or_(*includes) & not_(or_(*excludes))
+        return or_(*includes).self_group() & not_(or_(*excludes).self_group())
+    def states_entity_filter(self) -> ClauseList:
+        """Generate the entity filter query."""
+        def _encoder(data: Any) -> Any:
+            """Nothing to encode for states since there is no json."""
+            return data
+        return self._generate_filter_for_columns((States.entity_id,), _encoder)
+    def events_entity_filter(self) -> ClauseList:
+        """Generate the entity filter query."""
+        _encoder = json.dumps
+        return or_(
+            (ENTITY_ID_IN_EVENT == _encoder(None))
+            & (OLD_ENTITY_ID_IN_EVENT == _encoder(None)),
+            self._generate_filter_for_columns(
+                (ENTITY_ID_IN_EVENT, OLD_ENTITY_ID_IN_EVENT), _encoder
+            ).self_group(),
+        )
-def _glob_to_like(glob_str: str) -> ClauseList:
+def _globs_to_like(
+    glob_strs: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any]
+) -> ClauseList:
     """Translate glob to sql."""
-    return States.entity_id.like(glob_str.translate(GLOB_TO_SQL_CHARS))
+    return or_(
+        column.like(encoder(glob_str.translate(GLOB_TO_SQL_CHARS)))
+        for glob_str in glob_strs
+        for column in columns
+    )
+def _entity_matcher(
+    entity_ids: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any]
+) -> ClauseList:
+    return or_(
+        column.in_([encoder(entity_id) for entity_id in entity_ids])
+        for column in columns
+    )
+def _domain_matcher(
+    domains: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any]
+) -> ClauseList:
+    return or_(
+        column.like(encoder(f"{domain}.%")) for domain in domains for column in columns
+    )
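
Taken together, `_generate_filter_for_columns` composes the pieces as `(include₁ OR include₂ ...) AND NOT (exclude₁ OR ...)`, parameterized by the target columns and an encoder (identity for states, `json.dumps` for events). A runnable sketch of that composition on a single illustrative column:

```python
from sqlalchemy import Column, Integer, String, not_, or_
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class DemoStates(Base):  # stand-in for the recorder States table
    __tablename__ = "demo_states"
    state_id = Column(Integer, primary_key=True)
    entity_id = Column(String(255))


def _encoder(value):  # states side: values stay plain strings
    return value


col = DemoStates.entity_id
includes = [
    col.like(_encoder("light.%")),          # an included domain
    col.in_([_encoder("switch.kitchen")]),  # an included entity
]
excludes = [
    col.like(_encoder("light.hallway%")),   # an excluded glob, already translated
]
clause = or_(*includes).self_group() & not_(or_(*excludes).self_group())
print(clause.compile(compile_kwargs={"literal_binds": True}))
```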

View File

@@ -236,7 +236,7 @@ def _significant_states_stmt(
     else:
         stmt += _ignore_domains_filter
     if filters and filters.has_config:
-        entity_filter = filters.entity_filter()
+        entity_filter = filters.states_entity_filter()
         stmt += lambda q: q.filter(entity_filter)
     stmt += lambda q: q.filter(States.last_updated > start_time)
@@ -528,7 +528,7 @@ def _get_states_for_all_stmt(
     )
     stmt += _ignore_domains_filter
     if filters and filters.has_config:
-        entity_filter = filters.entity_filter()
+        entity_filter = filters.states_entity_filter()
         stmt += lambda q: q.filter(entity_filter)
     if join_attributes:
         stmt += lambda q: q.outerjoin(

View File

@@ -1,6 +1,7 @@
 """Models for SQLAlchemy."""
 from __future__ import annotations
+from collections.abc import Callable
 from datetime import datetime, timedelta
 import json
 import logging
@@ -9,6 +10,7 @@ from typing import Any, TypedDict, cast, overload
 import ciso8601
 from fnvhash import fnv1a_32
 from sqlalchemy import (
+    JSON,
     BigInteger,
     Boolean,
     Column,
@@ -22,11 +24,12 @@ from sqlalchemy import (
     String,
     Text,
     distinct,
+    type_coerce,
 )
 from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite
 from sqlalchemy.engine.row import Row
 from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import declarative_base, relationship
+from sqlalchemy.orm import aliased, declarative_base, relationship
 from sqlalchemy.orm.session import Session
 from homeassistant.components.websocket_api.const import (
@@ -119,6 +122,21 @@ DOUBLE_TYPE = (
     .with_variant(oracle.DOUBLE_PRECISION(), "oracle")
     .with_variant(postgresql.DOUBLE_PRECISION(), "postgresql")
 )
+class JSONLiteral(JSON):  # type: ignore[misc]
+    """Teach SA how to literalize json."""
+    def literal_processor(self, dialect: str) -> Callable[[Any], str]:
+        """Processor to convert a value to JSON."""
+        def process(value: Any) -> str:
+            """Dump json."""
+            return json.dumps(value)
+        return process
 EVENT_ORIGIN_ORDER = [EventOrigin.local, EventOrigin.remote]
 EVENT_ORIGIN_TO_IDX = {origin: idx for idx, origin in enumerate(EVENT_ORIGIN_ORDER)}
@@ -612,6 +630,26 @@ class StatisticsRuns(Base):  # type: ignore[misc,valid-type]
         )
+EVENT_DATA_JSON = type_coerce(
+    EventData.shared_data.cast(JSONB_VARIENT_CAST), JSONLiteral(none_as_null=True)
+)
+OLD_FORMAT_EVENT_DATA_JSON = type_coerce(
+    Events.event_data.cast(JSONB_VARIENT_CAST), JSONLiteral(none_as_null=True)
+)
+SHARED_ATTRS_JSON = type_coerce(
+    StateAttributes.shared_attrs.cast(JSON_VARIENT_CAST), JSON(none_as_null=True)
+)
+OLD_FORMAT_ATTRS_JSON = type_coerce(
+    States.attributes.cast(JSON_VARIENT_CAST), JSON(none_as_null=True)
+)
+ENTITY_ID_IN_EVENT: Column = EVENT_DATA_JSON["entity_id"]
+OLD_ENTITY_ID_IN_EVENT: Column = OLD_FORMAT_EVENT_DATA_JSON["entity_id"]
+DEVICE_ID_IN_EVENT: Column = EVENT_DATA_JSON["device_id"]
+OLD_STATE = aliased(States, name="old_state")
 @overload
 def process_timestamp(ts: None) -> None:
     ...
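
Why `JSONLiteral` is needed at all: the stock `JSON` type has no `literal_processor`, so compiling a statement with `literal_binds` (as the logbook's debug logging used to do) raises a compile error for JSON-typed binds. A minimal sketch with an unattached column:

```python
from collections.abc import Callable
import json
from typing import Any

from sqlalchemy import JSON, Column


class JSONLiteral(JSON):
    """Teach SQLAlchemy how to render JSON values as literals."""

    def literal_processor(self, dialect: str) -> Callable[[Any], str]:
        def process(value: Any) -> str:
            return json.dumps(value)

        return process


shared_data = Column("shared_data", JSONLiteral(none_as_null=True))
expr = shared_data == {"entity_id": "light.kitchen"}
# With the plain JSON() type this compile raises; JSONLiteral inlines the value:
print(expr.compile(compile_kwargs={"literal_binds": True}))
# shared_data = {"entity_id": "light.kitchen"}
```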

View File

@@ -510,7 +510,7 @@ async def test_exclude_described_event(hass, hass_client, recorder_mock):
         return {
             "name": "Test Name",
             "message": "tested a message",
-            "entity_id": event.data.get(ATTR_ENTITY_ID),
+            "entity_id": event.data[ATTR_ENTITY_ID],
         }
     def async_describe_events(hass, async_describe_event):
@@ -2003,13 +2003,12 @@ async def test_include_events_domain_glob(hass, hass_client, recorder_mock):
     )
     await async_recorder_block_till_done(hass)
+    # Should get excluded by domain
     hass.bus.async_fire(
         logbook.EVENT_LOGBOOK_ENTRY,
         {
             logbook.ATTR_NAME: "Alarm",
             logbook.ATTR_MESSAGE: "is triggered",
             logbook.ATTR_DOMAIN: "switch",
-            logbook.ATTR_ENTITY_ID: "switch.any",
         },
     )
     hass.bus.async_fire(EVENT_HOMEASSISTANT_START)

View File

@@ -14,16 +14,21 @@ from homeassistant.components.logbook import websocket_api
 from homeassistant.components.script import EVENT_SCRIPT_STARTED
 from homeassistant.components.websocket_api.const import TYPE_RESULT
 from homeassistant.const import (
+    ATTR_DOMAIN,
     ATTR_ENTITY_ID,
     ATTR_FRIENDLY_NAME,
     ATTR_NAME,
     ATTR_UNIT_OF_MEASUREMENT,
+    CONF_DOMAINS,
+    CONF_ENTITIES,
+    CONF_EXCLUDE,
     EVENT_HOMEASSISTANT_START,
     STATE_OFF,
     STATE_ON,
 )
 from homeassistant.core import Event, HomeAssistant, State
 from homeassistant.helpers import device_registry
+from homeassistant.helpers.entityfilter import CONF_ENTITY_GLOBS
 from homeassistant.setup import async_setup_component
 import homeassistant.util.dt as dt_util
@@ -457,6 +462,186 @@ async def test_get_events_with_device_ids(hass, hass_ws_client, recorder_mock):
     assert isinstance(results[3]["when"], float)
+@patch("homeassistant.components.logbook.websocket_api.EVENT_COALESCE_TIME", 0)
+async def test_subscribe_unsubscribe_logbook_stream_excluded_entities(
+    hass, recorder_mock, hass_ws_client
+):
+    """Test subscribe/unsubscribe logbook stream with excluded entities."""
+    now = dt_util.utcnow()
+    await asyncio.gather(
+        *[
+            async_setup_component(hass, comp, {})
+            for comp in ("homeassistant", "automation", "script")
+        ]
+    )
+    await async_setup_component(
+        hass,
+        logbook.DOMAIN,
+        {
+            logbook.DOMAIN: {
+                CONF_EXCLUDE: {
+                    CONF_ENTITIES: ["light.exc"],
+                    CONF_DOMAINS: ["switch"],
+                    CONF_ENTITY_GLOBS: "*.excluded",
+                }
+            },
+        },
+    )
+    await hass.async_block_till_done()
+    init_count = sum(hass.bus.async_listeners().values())
+    hass.states.async_set("light.exc", STATE_ON)
+    hass.states.async_set("light.exc", STATE_OFF)
+    hass.states.async_set("switch.any", STATE_ON)
+    hass.states.async_set("switch.any", STATE_OFF)
+    hass.states.async_set("cover.excluded", STATE_ON)
+    hass.states.async_set("cover.excluded", STATE_OFF)
+    hass.states.async_set("binary_sensor.is_light", STATE_ON)
+    hass.states.async_set("binary_sensor.is_light", STATE_OFF)
+    state: State = hass.states.get("binary_sensor.is_light")
+    await hass.async_block_till_done()
+    await async_wait_recording_done(hass)
+    websocket_client = await hass_ws_client()
+    await websocket_client.send_json(
+        {"id": 7, "type": "logbook/event_stream", "start_time": now.isoformat()}
+    )
+    msg = await asyncio.wait_for(websocket_client.receive_json(), 2)
+    assert msg["id"] == 7
+    assert msg["type"] == TYPE_RESULT
+    assert msg["success"]
+    msg = await asyncio.wait_for(websocket_client.receive_json(), 2)
+    assert msg["id"] == 7
+    assert msg["type"] == "event"
+    assert msg["event"]["events"] == [
+        {
+            "entity_id": "binary_sensor.is_light",
+            "state": "off",
+            "when": state.last_updated.timestamp(),
+        }
+    ]
+    assert msg["event"]["start_time"] == now.timestamp()
+    assert msg["event"]["end_time"] > msg["event"]["start_time"]
+    assert msg["event"]["partial"] is True
+    hass.states.async_set("light.exc", STATE_ON)
+    hass.states.async_set("light.exc", STATE_OFF)
+    hass.states.async_set("switch.any", STATE_ON)
+    hass.states.async_set("switch.any", STATE_OFF)
+    hass.states.async_set("cover.excluded", STATE_ON)
+    hass.states.async_set("cover.excluded", STATE_OFF)
+    hass.states.async_set("light.alpha", "on")
+    hass.states.async_set("light.alpha", "off")
+    alpha_off_state: State = hass.states.get("light.alpha")
+    hass.states.async_set("light.zulu", "on", {"color": "blue"})
+    hass.states.async_set("light.zulu", "off", {"effect": "help"})
+    zulu_off_state: State = hass.states.get("light.zulu")
+    hass.states.async_set(
+        "light.zulu", "on", {"effect": "help", "color": ["blue", "green"]}
+    )
+    zulu_on_state: State = hass.states.get("light.zulu")
+    await hass.async_block_till_done()
+    hass.states.async_remove("light.zulu")
+    await hass.async_block_till_done()
+    hass.states.async_set("light.zulu", "on", {"effect": "help", "color": "blue"})
+    msg = await asyncio.wait_for(websocket_client.receive_json(), 2)
+    assert msg["id"] == 7
+    assert msg["type"] == "event"
+    assert "partial" not in msg["event"]["events"]
+    assert msg["event"]["events"] == []
+    msg = await asyncio.wait_for(websocket_client.receive_json(), 2)
+    assert msg["id"] == 7
+    assert msg["type"] == "event"
+    assert "partial" not in msg["event"]["events"]
+    assert msg["event"]["events"] == [
+        {
+            "entity_id": "light.alpha",
+            "state": "off",
+            "when": alpha_off_state.last_updated.timestamp(),
+        },
+        {
+            "entity_id": "light.zulu",
+            "state": "off",
+            "when": zulu_off_state.last_updated.timestamp(),
+        },
+        {
+            "entity_id": "light.zulu",
+            "state": "on",
+            "when": zulu_on_state.last_updated.timestamp(),
+        },
+    ]
+    await async_wait_recording_done(hass)
+    hass.bus.async_fire(
+        EVENT_AUTOMATION_TRIGGERED,
+        {ATTR_NAME: "Mock automation 3", ATTR_ENTITY_ID: "cover.excluded"},
+    )
+    hass.bus.async_fire(
+        EVENT_AUTOMATION_TRIGGERED,
+        {
+            ATTR_NAME: "Mock automation switch matching entity",
+            ATTR_ENTITY_ID: "switch.match_domain",
+        },
+    )
+    hass.bus.async_fire(
+        EVENT_AUTOMATION_TRIGGERED,
+        {ATTR_NAME: "Mock automation switch matching domain", ATTR_DOMAIN: "switch"},
+    )
+    hass.bus.async_fire(
+        EVENT_AUTOMATION_TRIGGERED,
+        {ATTR_NAME: "Mock automation matches nothing"},
+    )
+    hass.bus.async_fire(
+        EVENT_AUTOMATION_TRIGGERED,
+        {ATTR_NAME: "Mock automation 3", ATTR_ENTITY_ID: "light.keep"},
+    )
+    hass.states.async_set("cover.excluded", STATE_ON)
+    hass.states.async_set("cover.excluded", STATE_OFF)
+    await hass.async_block_till_done()
+    msg = await websocket_client.receive_json()
+    assert msg["id"] == 7
+    assert msg["type"] == "event"
+    assert msg["event"]["events"] == [
+        {
+            "context_id": ANY,
+            "domain": "automation",
+            "entity_id": None,
+            "message": "triggered",
+            "name": "Mock automation matches nothing",
+            "source": None,
+            "when": ANY,
+        },
+        {
+            "context_id": ANY,
+            "domain": "automation",
+            "entity_id": "light.keep",
+            "message": "triggered",
+            "name": "Mock automation 3",
+            "source": None,
+            "when": ANY,
+        },
+    ]
+    await websocket_client.send_json(
+        {"id": 8, "type": "unsubscribe_events", "subscription": 7}
+    )
+    msg = await asyncio.wait_for(websocket_client.receive_json(), 2)
+    assert msg["id"] == 8
+    assert msg["type"] == TYPE_RESULT
+    assert msg["success"]
+    # Check our listener got unsubscribed
+    assert sum(hass.bus.async_listeners().values()) == init_count
 @patch("homeassistant.components.logbook.websocket_api.EVENT_COALESCE_TIME", 0)
 async def test_subscribe_unsubscribe_logbook_stream(
     hass, recorder_mock, hass_ws_client