Relocate sqlalchemy filter builder to recorder/filters.py (#71883)

pull/71900/head
J. Nick Koston 2022-05-15 01:04:23 -05:00 committed by GitHub
parent 65f44bd80b
commit 1f753ecd88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 177 additions and 183 deletions

View File

@ -6,20 +6,17 @@ from datetime import datetime as dt, timedelta
from http import HTTPStatus
import logging
import time
from typing import Any, Literal, cast
from typing import Literal, cast
from aiohttp import web
from sqlalchemy import not_, or_
from sqlalchemy.ext.baked import BakedQuery
from sqlalchemy.orm import Query
import voluptuous as vol
from homeassistant.components import frontend, websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.recorder import (
get_instance,
history,
models as history_models,
from homeassistant.components.recorder import get_instance, history
from homeassistant.components.recorder.filters import (
Filters,
sqlalchemy_filter_from_include_exclude_conf,
)
from homeassistant.components.recorder.statistics import (
list_statistic_ids,
@ -28,13 +25,9 @@ from homeassistant.components.recorder.statistics import (
from homeassistant.components.recorder.util import session_scope
from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import (
CONF_ENTITY_GLOBS,
INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA,
)
from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA
from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util
@ -46,10 +39,6 @@ HISTORY_USE_INCLUDE_ORDER = "history_use_include_order"
CONF_ORDER = "use_include_order"
GLOB_TO_SQL_CHARS = {
42: "%", # *
46: "_", # .
}
CONFIG_SCHEMA = vol.Schema(
{
@ -410,112 +399,6 @@ class HistoryPeriodView(HomeAssistantView):
return self.json(sorted_result)
def sqlalchemy_filter_from_include_exclude_conf(conf: ConfigType) -> Filters | None:
"""Build a sql filter from config."""
filters = Filters()
if exclude := conf.get(CONF_EXCLUDE):
filters.excluded_entities = exclude.get(CONF_ENTITIES, [])
filters.excluded_domains = exclude.get(CONF_DOMAINS, [])
filters.excluded_entity_globs = exclude.get(CONF_ENTITY_GLOBS, [])
if include := conf.get(CONF_INCLUDE):
filters.included_entities = include.get(CONF_ENTITIES, [])
filters.included_domains = include.get(CONF_DOMAINS, [])
filters.included_entity_globs = include.get(CONF_ENTITY_GLOBS, [])
return filters if filters.has_config else None
class Filters:
"""Container for the configured include and exclude filters."""
def __init__(self) -> None:
"""Initialise the include and exclude filters."""
self.excluded_entities: list[str] = []
self.excluded_domains: list[str] = []
self.excluded_entity_globs: list[str] = []
self.included_entities: list[str] = []
self.included_domains: list[str] = []
self.included_entity_globs: list[str] = []
def apply(self, query: Query) -> Query:
"""Apply the entity filter."""
if not self.has_config:
return query
return query.filter(self.entity_filter())
@property
def has_config(self) -> bool:
"""Determine if there is any filter configuration."""
return bool(
self.excluded_entities
or self.excluded_domains
or self.excluded_entity_globs
or self.included_entities
or self.included_domains
or self.included_entity_globs
)
def bake(self, baked_query: BakedQuery) -> None:
"""Update a baked query.
Works the same as apply on a baked_query.
"""
if not self.has_config:
return
baked_query += lambda q: q.filter(self.entity_filter())
def entity_filter(self) -> Any:
"""Generate the entity filter query."""
includes = []
if self.included_domains:
includes.append(
or_(
*[
history_models.States.entity_id.like(f"{domain}.%")
for domain in self.included_domains
]
).self_group()
)
if self.included_entities:
includes.append(history_models.States.entity_id.in_(self.included_entities))
for glob in self.included_entity_globs:
includes.append(_glob_to_like(glob))
excludes = []
if self.excluded_domains:
excludes.append(
or_(
*[
history_models.States.entity_id.like(f"{domain}.%")
for domain in self.excluded_domains
]
).self_group()
)
if self.excluded_entities:
excludes.append(history_models.States.entity_id.in_(self.excluded_entities))
for glob in self.excluded_entity_globs:
excludes.append(_glob_to_like(glob))
if not includes and not excludes:
return None
if includes and not excludes:
return or_(*includes)
if not includes and excludes:
return not_(or_(*excludes))
return or_(*includes) & not_(or_(*excludes))
def _glob_to_like(glob_str: str) -> Any:
"""Translate glob to sql."""
return history_models.States.entity_id.like(glob_str.translate(GLOB_TO_SQL_CHARS))
def _entities_may_have_state_changes_after(
hass: HomeAssistant, entity_ids: Iterable, start_time: dt
) -> bool:

View File

@ -17,12 +17,12 @@ import voluptuous as vol
from homeassistant.components import frontend, websocket_api
from homeassistant.components.automation import EVENT_AUTOMATION_TRIGGERED
from homeassistant.components.history import (
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.recorder import get_instance
from homeassistant.components.recorder.filters import (
Filters,
sqlalchemy_filter_from_include_exclude_conf,
)
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.recorder import get_instance
from homeassistant.components.recorder.models import (
process_datetime_to_timestamp,
process_timestamp_to_utc_isoformat,

View File

@ -3,17 +3,17 @@ from __future__ import annotations
from collections.abc import Iterable
from datetime import datetime as dt
from typing import Any
import sqlalchemy
from sqlalchemy import lambda_stmt, select, union_all
from sqlalchemy.orm import Query, aliased
from sqlalchemy.sql.elements import ClauseList
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.sql.selectable import Select
from homeassistant.components.history import Filters
from homeassistant.components.proximity import DOMAIN as PROXIMITY_DOMAIN
from homeassistant.components.recorder.filters import Filters
from homeassistant.components.recorder.models import (
ENTITY_ID_LAST_UPDATED_INDEX,
LAST_UPDATED_INDEX,
@ -236,7 +236,7 @@ def _all_stmt(
start_day: dt,
end_day: dt,
event_types: tuple[str, ...],
entity_filter: Any | None = None,
entity_filter: ClauseList | None = None,
context_id: str | None = None,
) -> StatementLambdaElement:
"""Generate a logbook query for all entities."""
@ -410,7 +410,7 @@ def _continuous_domain_matcher() -> sqlalchemy.or_:
).self_group()
def _not_uom_attributes_matcher() -> Any:
def _not_uom_attributes_matcher() -> ClauseList:
"""Prefilter ATTR_UNIT_OF_MEASUREMENT as its much faster in sql."""
return ~StateAttributes.shared_attrs.like(
UNIT_OF_MEASUREMENT_JSON_LIKE

View File

@ -0,0 +1,119 @@
"""Provide pre-made queries on top of the recorder component."""
from __future__ import annotations
from sqlalchemy import not_, or_
from sqlalchemy.ext.baked import BakedQuery
from sqlalchemy.sql.elements import ClauseList
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
from homeassistant.helpers.entityfilter import CONF_ENTITY_GLOBS
from homeassistant.helpers.typing import ConfigType
from .models import States
DOMAIN = "history"
HISTORY_FILTERS = "history_filters"
GLOB_TO_SQL_CHARS = {
42: "%", # *
46: "_", # .
}
def sqlalchemy_filter_from_include_exclude_conf(conf: ConfigType) -> Filters | None:
"""Build a sql filter from config."""
filters = Filters()
if exclude := conf.get(CONF_EXCLUDE):
filters.excluded_entities = exclude.get(CONF_ENTITIES, [])
filters.excluded_domains = exclude.get(CONF_DOMAINS, [])
filters.excluded_entity_globs = exclude.get(CONF_ENTITY_GLOBS, [])
if include := conf.get(CONF_INCLUDE):
filters.included_entities = include.get(CONF_ENTITIES, [])
filters.included_domains = include.get(CONF_DOMAINS, [])
filters.included_entity_globs = include.get(CONF_ENTITY_GLOBS, [])
return filters if filters.has_config else None
class Filters:
"""Container for the configured include and exclude filters."""
def __init__(self) -> None:
"""Initialise the include and exclude filters."""
self.excluded_entities: list[str] = []
self.excluded_domains: list[str] = []
self.excluded_entity_globs: list[str] = []
self.included_entities: list[str] = []
self.included_domains: list[str] = []
self.included_entity_globs: list[str] = []
@property
def has_config(self) -> bool:
"""Determine if there is any filter configuration."""
return bool(
self.excluded_entities
or self.excluded_domains
or self.excluded_entity_globs
or self.included_entities
or self.included_domains
or self.included_entity_globs
)
def bake(self, baked_query: BakedQuery) -> BakedQuery:
"""Update a baked query.
Works the same as apply on a baked_query.
"""
if not self.has_config:
return
baked_query += lambda q: q.filter(self.entity_filter())
def entity_filter(self) -> ClauseList:
"""Generate the entity filter query."""
includes = []
if self.included_domains:
includes.append(
or_(
*[
States.entity_id.like(f"{domain}.%")
for domain in self.included_domains
]
).self_group()
)
if self.included_entities:
includes.append(States.entity_id.in_(self.included_entities))
for glob in self.included_entity_globs:
includes.append(_glob_to_like(glob))
excludes = []
if self.excluded_domains:
excludes.append(
or_(
*[
States.entity_id.like(f"{domain}.%")
for domain in self.excluded_domains
]
).self_group()
)
if self.excluded_entities:
excludes.append(States.entity_id.in_(self.excluded_entities))
for glob in self.excluded_entity_globs:
excludes.append(_glob_to_like(glob))
if not includes and not excludes:
return None
if includes and not excludes:
return or_(*includes)
if not includes and excludes:
return not_(or_(*excludes))
return or_(*includes) & not_(or_(*excludes))
def _glob_to_like(glob_str: str) -> ClauseList:
"""Translate glob to sql."""
return States.entity_id.like(glob_str.translate(GLOB_TO_SQL_CHARS))

View File

@ -25,6 +25,7 @@ from homeassistant.components.websocket_api.const import (
from homeassistant.core import HomeAssistant, State, split_entity_id
import homeassistant.util.dt as dt_util
from .filters import Filters
from .models import (
LazyState,
RecorderRuns,
@ -163,7 +164,7 @@ def get_significant_states(
start_time: datetime,
end_time: datetime | None = None,
entity_ids: list[str] | None = None,
filters: Any | None = None,
filters: Filters | None = None,
include_start_time_state: bool = True,
significant_changes_only: bool = True,
minimal_response: bool = False,
@ -205,7 +206,7 @@ def _query_significant_states_with_session(
start_time: datetime,
end_time: datetime | None = None,
entity_ids: list[str] | None = None,
filters: Any = None,
filters: Filters | None = None,
significant_changes_only: bool = True,
no_attributes: bool = False,
) -> list[Row]:
@ -281,7 +282,7 @@ def get_significant_states_with_session(
start_time: datetime,
end_time: datetime | None = None,
entity_ids: list[str] | None = None,
filters: Any = None,
filters: Filters | None = None,
include_start_time_state: bool = True,
significant_changes_only: bool = True,
minimal_response: bool = False,
@ -330,7 +331,7 @@ def get_full_significant_states_with_session(
start_time: datetime,
end_time: datetime | None = None,
entity_ids: list[str] | None = None,
filters: Any = None,
filters: Filters | None = None,
include_start_time_state: bool = True,
significant_changes_only: bool = True,
no_attributes: bool = False,
@ -549,7 +550,7 @@ def _most_recent_state_ids_subquery(query: Query) -> Query:
def _get_states_baked_query_for_all(
hass: HomeAssistant,
filters: Any | None = None,
filters: Filters | None = None,
no_attributes: bool = False,
) -> BakedQuery:
"""Baked query to get states for all entities."""
@ -573,7 +574,7 @@ def _get_rows_with_session(
utc_point_in_time: datetime,
entity_ids: list[str] | None = None,
run: RecorderRuns | None = None,
filters: Any | None = None,
filters: Filters | None = None,
no_attributes: bool = False,
) -> list[Row]:
"""Return the states at a specific point in time."""
@ -640,7 +641,7 @@ def _sorted_states_to_dict(
states: Iterable[Row],
start_time: datetime,
entity_ids: list[str] | None,
filters: Any = None,
filters: Filters | None = None,
include_start_time_state: bool = True,
minimal_response: bool = False,
no_attributes: bool = False,

View File

@ -2,6 +2,7 @@
import pytest
from homeassistant.components import history
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
from homeassistant.setup import setup_component
@ -13,13 +14,13 @@ def hass_history(hass_recorder):
config = history.CONFIG_SCHEMA(
{
history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_DOMAINS: ["media_player"],
history.CONF_ENTITIES: ["thermostat.test"],
CONF_INCLUDE: {
CONF_DOMAINS: ["media_player"],
CONF_ENTITIES: ["thermostat.test"],
},
history.CONF_EXCLUDE: {
history.CONF_DOMAINS: ["thermostat"],
history.CONF_ENTITIES: ["media_player.test"],
CONF_EXCLUDE: {
CONF_DOMAINS: ["thermostat"],
CONF_ENTITIES: ["media_player.test"],
},
}
}

View File

@ -11,7 +11,7 @@ from pytest import approx
from homeassistant.components import history
from homeassistant.components.recorder.history import get_significant_states
from homeassistant.components.recorder.models import process_timestamp
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
import homeassistant.core as ha
from homeassistant.helpers.json import JSONEncoder
from homeassistant.setup import async_setup_component
@ -186,9 +186,7 @@ def test_get_significant_states_exclude_domain(hass_history):
config = history.CONFIG_SCHEMA(
{
ha.DOMAIN: {},
history.DOMAIN: {
history.CONF_EXCLUDE: {history.CONF_DOMAINS: ["media_player"]}
},
history.DOMAIN: {CONF_EXCLUDE: {CONF_DOMAINS: ["media_player"]}},
}
)
check_significant_states(hass, zero, four, states, config)
@ -207,9 +205,7 @@ def test_get_significant_states_exclude_entity(hass_history):
config = history.CONFIG_SCHEMA(
{
ha.DOMAIN: {},
history.DOMAIN: {
history.CONF_EXCLUDE: {history.CONF_ENTITIES: ["media_player.test"]}
},
history.DOMAIN: {CONF_EXCLUDE: {CONF_ENTITIES: ["media_player.test"]}},
}
)
check_significant_states(hass, zero, four, states, config)
@ -230,9 +226,9 @@ def test_get_significant_states_exclude(hass_history):
{
ha.DOMAIN: {},
history.DOMAIN: {
history.CONF_EXCLUDE: {
history.CONF_DOMAINS: ["thermostat"],
history.CONF_ENTITIES: ["media_player.test"],
CONF_EXCLUDE: {
CONF_DOMAINS: ["thermostat"],
CONF_ENTITIES: ["media_player.test"],
}
},
}
@ -257,10 +253,8 @@ def test_get_significant_states_exclude_include_entity(hass_history):
{
ha.DOMAIN: {},
history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_ENTITIES: ["media_player.test", "thermostat.test"]
},
history.CONF_EXCLUDE: {history.CONF_DOMAINS: ["thermostat"]},
CONF_INCLUDE: {CONF_ENTITIES: ["media_player.test", "thermostat.test"]},
CONF_EXCLUDE: {CONF_DOMAINS: ["thermostat"]},
},
}
)
@ -282,9 +276,7 @@ def test_get_significant_states_include_domain(hass_history):
config = history.CONFIG_SCHEMA(
{
ha.DOMAIN: {},
history.DOMAIN: {
history.CONF_INCLUDE: {history.CONF_DOMAINS: ["thermostat", "script"]}
},
history.DOMAIN: {CONF_INCLUDE: {CONF_DOMAINS: ["thermostat", "script"]}},
}
)
check_significant_states(hass, zero, four, states, config)
@ -306,9 +298,7 @@ def test_get_significant_states_include_entity(hass_history):
config = history.CONFIG_SCHEMA(
{
ha.DOMAIN: {},
history.DOMAIN: {
history.CONF_INCLUDE: {history.CONF_ENTITIES: ["media_player.test"]}
},
history.DOMAIN: {CONF_INCLUDE: {CONF_ENTITIES: ["media_player.test"]}},
}
)
check_significant_states(hass, zero, four, states, config)
@ -330,9 +320,9 @@ def test_get_significant_states_include(hass_history):
{
ha.DOMAIN: {},
history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_DOMAINS: ["thermostat"],
history.CONF_ENTITIES: ["media_player.test"],
CONF_INCLUDE: {
CONF_DOMAINS: ["thermostat"],
CONF_ENTITIES: ["media_player.test"],
}
},
}
@ -359,8 +349,8 @@ def test_get_significant_states_include_exclude_domain(hass_history):
{
ha.DOMAIN: {},
history.DOMAIN: {
history.CONF_INCLUDE: {history.CONF_DOMAINS: ["media_player"]},
history.CONF_EXCLUDE: {history.CONF_DOMAINS: ["media_player"]},
CONF_INCLUDE: {CONF_DOMAINS: ["media_player"]},
CONF_EXCLUDE: {CONF_DOMAINS: ["media_player"]},
},
}
)
@ -386,8 +376,8 @@ def test_get_significant_states_include_exclude_entity(hass_history):
{
ha.DOMAIN: {},
history.DOMAIN: {
history.CONF_INCLUDE: {history.CONF_ENTITIES: ["media_player.test"]},
history.CONF_EXCLUDE: {history.CONF_ENTITIES: ["media_player.test"]},
CONF_INCLUDE: {CONF_ENTITIES: ["media_player.test"]},
CONF_EXCLUDE: {CONF_ENTITIES: ["media_player.test"]},
},
}
)
@ -410,13 +400,13 @@ def test_get_significant_states_include_exclude(hass_history):
{
ha.DOMAIN: {},
history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_DOMAINS: ["media_player"],
history.CONF_ENTITIES: ["thermostat.test"],
CONF_INCLUDE: {
CONF_DOMAINS: ["media_player"],
CONF_ENTITIES: ["thermostat.test"],
},
history.CONF_EXCLUDE: {
history.CONF_DOMAINS: ["thermostat"],
history.CONF_ENTITIES: ["media_player.test"],
CONF_EXCLUDE: {
CONF_DOMAINS: ["thermostat"],
CONF_ENTITIES: ["media_player.test"],
},
},
}
@ -503,14 +493,14 @@ def test_get_significant_states_only(hass_history):
def check_significant_states(hass, zero, four, states, config):
"""Check if significant states are retrieved."""
filters = history.Filters()
exclude = config[history.DOMAIN].get(history.CONF_EXCLUDE)
exclude = config[history.DOMAIN].get(CONF_EXCLUDE)
if exclude:
filters.excluded_entities = exclude.get(history.CONF_ENTITIES, [])
filters.excluded_domains = exclude.get(history.CONF_DOMAINS, [])
include = config[history.DOMAIN].get(history.CONF_INCLUDE)
filters.excluded_entities = exclude.get(CONF_ENTITIES, [])
filters.excluded_domains = exclude.get(CONF_DOMAINS, [])
include = config[history.DOMAIN].get(CONF_INCLUDE)
if include:
filters.included_entities = include.get(history.CONF_ENTITIES, [])
filters.included_domains = include.get(history.CONF_DOMAINS, [])
filters.included_entities = include.get(CONF_ENTITIES, [])
filters.included_domains = include.get(CONF_DOMAINS, [])
hist = get_significant_states(hass, zero, four, filters=filters)
assert states == hist
@ -1496,7 +1486,7 @@ async def test_history_during_period_with_use_include_order(
{
history.DOMAIN: {
history.CONF_ORDER: True,
history.CONF_INCLUDE: {
CONF_INCLUDE: {
CONF_ENTITIES: sort_order,
CONF_DOMAINS: ["sensor"],
},