Avoid selecting attributes in the history api when `no_attributes` is passed (#68352)

pull/68404/head^2
J. Nick Koston 2022-03-19 23:47:22 -10:00 committed by GitHub
parent a0a96dab05
commit 816695cc96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 201 additions and 51 deletions

View File

@ -224,6 +224,7 @@ class HistoryPeriodView(HomeAssistantView):
) )
minimal_response = "minimal_response" in request.query minimal_response = "minimal_response" in request.query
no_attributes = "no_attributes" in request.query
hass = request.app["hass"] hass = request.app["hass"]
@ -245,6 +246,7 @@ class HistoryPeriodView(HomeAssistantView):
include_start_time_state, include_start_time_state,
significant_changes_only, significant_changes_only,
minimal_response, minimal_response,
no_attributes,
), ),
) )
@ -257,6 +259,7 @@ class HistoryPeriodView(HomeAssistantView):
include_start_time_state, include_start_time_state,
significant_changes_only, significant_changes_only,
minimal_response, minimal_response,
no_attributes,
): ):
"""Fetch significant stats from the database as json.""" """Fetch significant stats from the database as json."""
timer_start = time.perf_counter() timer_start = time.perf_counter()
@ -272,6 +275,7 @@ class HistoryPeriodView(HomeAssistantView):
include_start_time_state, include_start_time_state,
significant_changes_only, significant_changes_only,
minimal_response, minimal_response,
no_attributes,
) )
result = list(result.values()) result = list(result.values())

View File

@ -2,15 +2,17 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from datetime import datetime
from itertools import groupby from itertools import groupby
import logging import logging
import time import time
from sqlalchemy import and_, bindparam, func from sqlalchemy import Text, and_, bindparam, func
from sqlalchemy.ext import baked from sqlalchemy.ext import baked
from sqlalchemy.sql.expression import literal
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.core import split_entity_id from homeassistant.core import HomeAssistant, State, split_entity_id
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .models import ( from .models import (
@ -44,13 +46,21 @@ NEED_ATTRIBUTE_DOMAINS = {
"water_heater", "water_heater",
} }
QUERY_STATES = [ BASE_STATES = [
States.domain, States.domain,
States.entity_id, States.entity_id,
States.state, States.state,
States.attributes,
States.last_changed, States.last_changed,
States.last_updated, States.last_updated,
]
QUERY_STATE_NO_ATTR = [
*BASE_STATES,
literal(value=None, type_=Text).label("attributes"),
literal(value=None, type_=Text).label("shared_attrs"),
]
QUERY_STATES = [
*BASE_STATES,
States.attributes,
StateAttributes.shared_attrs, StateAttributes.shared_attrs,
] ]
@ -78,6 +88,7 @@ def get_significant_states_with_session(
include_start_time_state=True, include_start_time_state=True,
significant_changes_only=True, significant_changes_only=True,
minimal_response=False, minimal_response=False,
no_attributes=False,
): ):
""" """
Return states changes during UTC period start_time - end_time. Return states changes during UTC period start_time - end_time.
@ -92,10 +103,8 @@ def get_significant_states_with_session(
thermostat so that we get current temperature in our graphs). thermostat so that we get current temperature in our graphs).
""" """
timer_start = time.perf_counter() timer_start = time.perf_counter()
query_keys = QUERY_STATE_NO_ATTR if no_attributes else QUERY_STATES
baked_query = hass.data[HISTORY_BAKERY]( baked_query = hass.data[HISTORY_BAKERY](lambda session: session.query(*query_keys))
lambda session: session.query(*QUERY_STATES)
)
if significant_changes_only: if significant_changes_only:
baked_query += lambda q: q.filter( baked_query += lambda q: q.filter(
@ -120,6 +129,7 @@ def get_significant_states_with_session(
if end_time is not None: if end_time is not None:
baked_query += lambda q: q.filter(States.last_updated < bindparam("end_time")) baked_query += lambda q: q.filter(States.last_updated < bindparam("end_time"))
if not no_attributes:
baked_query += lambda q: q.outerjoin( baked_query += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id StateAttributes, States.attributes_id == StateAttributes.attributes_id
) )
@ -144,14 +154,25 @@ def get_significant_states_with_session(
filters, filters,
include_start_time_state, include_start_time_state,
minimal_response, minimal_response,
no_attributes,
) )
def state_changes_during_period(hass, start_time, end_time=None, entity_id=None): def state_changes_during_period(
hass: HomeAssistant,
start_time: datetime,
end_time: datetime | None = None,
entity_id: str | None = None,
no_attributes: bool = False,
descending: bool = False,
limit: int | None = None,
include_start_time_state: bool = True,
) -> dict[str, list[State]]:
"""Return states changes during UTC period start_time - end_time.""" """Return states changes during UTC period start_time - end_time."""
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
query_keys = QUERY_STATE_NO_ATTR if no_attributes else QUERY_STATES
baked_query = hass.data[HISTORY_BAKERY]( baked_query = hass.data[HISTORY_BAKERY](
lambda session: session.query(*QUERY_STATES) lambda session: session.query(*query_keys)
) )
baked_query += lambda q: q.filter( baked_query += lambda q: q.filter(
@ -168,10 +189,16 @@ def state_changes_during_period(hass, start_time, end_time=None, entity_id=None)
baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id")) baked_query += lambda q: q.filter_by(entity_id=bindparam("entity_id"))
entity_id = entity_id.lower() entity_id = entity_id.lower()
if not no_attributes:
baked_query += lambda q: q.outerjoin( baked_query += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id StateAttributes, States.attributes_id == StateAttributes.attributes_id
) )
baked_query += lambda q: q.order_by(States.entity_id, States.last_updated)
last_updated = States.last_updated.desc() if descending else States.last_updated
baked_query += lambda q: q.order_by(States.entity_id, last_updated)
if limit:
baked_query += lambda q: q.limit(limit)
states = execute( states = execute(
baked_query(session).params( baked_query(session).params(
@ -181,7 +208,14 @@ def state_changes_during_period(hass, start_time, end_time=None, entity_id=None)
entity_ids = [entity_id] if entity_id is not None else None entity_ids = [entity_id] if entity_id is not None else None
return _sorted_states_to_dict(hass, session, states, start_time, entity_ids) return _sorted_states_to_dict(
hass,
session,
states,
start_time,
entity_ids,
include_start_time_state=include_start_time_state,
)
def get_last_state_changes(hass, number_of_states, entity_id): def get_last_state_changes(hass, number_of_states, entity_id):
@ -225,7 +259,14 @@ def get_last_state_changes(hass, number_of_states, entity_id):
) )
def get_states(hass, utc_point_in_time, entity_ids=None, run=None, filters=None): def get_states(
hass,
utc_point_in_time,
entity_ids=None,
run=None,
filters=None,
no_attributes=False,
):
"""Return the states at a specific point in time.""" """Return the states at a specific point in time."""
if run is None: if run is None:
run = recorder.run_information_from_instance(hass, utc_point_in_time) run = recorder.run_information_from_instance(hass, utc_point_in_time)
@ -236,17 +277,23 @@ def get_states(hass, utc_point_in_time, entity_ids=None, run=None, filters=None)
with session_scope(hass=hass) as session: with session_scope(hass=hass) as session:
return _get_states_with_session( return _get_states_with_session(
hass, session, utc_point_in_time, entity_ids, run, filters hass, session, utc_point_in_time, entity_ids, run, filters, no_attributes
) )
def _get_states_with_session( def _get_states_with_session(
hass, session, utc_point_in_time, entity_ids=None, run=None, filters=None hass,
session,
utc_point_in_time,
entity_ids=None,
run=None,
filters=None,
no_attributes=False,
): ):
"""Return the states at a specific point in time.""" """Return the states at a specific point in time."""
if entity_ids and len(entity_ids) == 1: if entity_ids and len(entity_ids) == 1:
return _get_single_entity_states_with_session( return _get_single_entity_states_with_session(
hass, session, utc_point_in_time, entity_ids[0] hass, session, utc_point_in_time, entity_ids[0], no_attributes
) )
if run is None: if run is None:
@ -258,7 +305,8 @@ def _get_states_with_session(
# We have more than one entity to look at so we need to do a query on states # We have more than one entity to look at so we need to do a query on states
# since the last recorder run started. # since the last recorder run started.
query = session.query(*QUERY_STATES) query_keys = QUERY_STATE_NO_ATTR if no_attributes else QUERY_STATES
query = session.query(*query_keys)
if entity_ids: if entity_ids:
# We got an include-list of entities, accelerate the query by filtering already # We got an include-list of entities, accelerate the query by filtering already
@ -278,7 +326,9 @@ def _get_states_with_session(
query = query.join( query = query.join(
most_recent_state_ids, most_recent_state_ids,
States.state_id == most_recent_state_ids.c.max_state_id, States.state_id == most_recent_state_ids.c.max_state_id,
).outerjoin( )
if not no_attributes:
query = query.outerjoin(
StateAttributes, (States.attributes_id == StateAttributes.attributes_id) StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
) )
else: else:
@ -318,6 +368,7 @@ def _get_states_with_session(
query = query.filter(~States.domain.in_(IGNORE_DOMAINS)) query = query.filter(~States.domain.in_(IGNORE_DOMAINS))
if filters: if filters:
query = filters.apply(query) query = filters.apply(query)
if not no_attributes:
query = query.outerjoin( query = query.outerjoin(
StateAttributes, (States.attributes_id == StateAttributes.attributes_id) StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
) )
@ -326,16 +377,18 @@ def _get_states_with_session(
return [LazyState(row, attr_cache) for row in execute(query)] return [LazyState(row, attr_cache) for row in execute(query)]
def _get_single_entity_states_with_session(hass, session, utc_point_in_time, entity_id): def _get_single_entity_states_with_session(
hass, session, utc_point_in_time, entity_id, no_attributes=False
):
# Use an entirely different (and extremely fast) query if we only # Use an entirely different (and extremely fast) query if we only
# have a single entity id # have a single entity id
baked_query = hass.data[HISTORY_BAKERY]( query_keys = QUERY_STATE_NO_ATTR if no_attributes else QUERY_STATES
lambda session: session.query(*QUERY_STATES) baked_query = hass.data[HISTORY_BAKERY](lambda session: session.query(*query_keys))
)
baked_query += lambda q: q.filter( baked_query += lambda q: q.filter(
States.last_updated < bindparam("utc_point_in_time"), States.last_updated < bindparam("utc_point_in_time"),
States.entity_id == bindparam("entity_id"), States.entity_id == bindparam("entity_id"),
) )
if not no_attributes:
baked_query += lambda q: q.outerjoin( baked_query += lambda q: q.outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id StateAttributes, States.attributes_id == StateAttributes.attributes_id
) )
@ -358,6 +411,7 @@ def _sorted_states_to_dict(
filters=None, filters=None,
include_start_time_state=True, include_start_time_state=True,
minimal_response=False, minimal_response=False,
no_attributes=False,
): ):
"""Convert SQL results into JSON friendly data structure. """Convert SQL results into JSON friendly data structure.
@ -381,7 +435,13 @@ def _sorted_states_to_dict(
if include_start_time_state: if include_start_time_state:
run = recorder.run_information_from_instance(hass, start_time) run = recorder.run_information_from_instance(hass, start_time)
for state in _get_states_with_session( for state in _get_states_with_session(
hass, session, start_time, entity_ids, run=run, filters=filters hass,
session,
start_time,
entity_ids,
run=run,
filters=filters,
no_attributes=no_attributes,
): ):
state.last_changed = start_time state.last_changed = start_time
state.last_updated = start_time state.last_updated = start_time
@ -440,7 +500,7 @@ def _sorted_states_to_dict(
return {key: val for key, val in result.items() if val} return {key: val for key, val in result.items() if val}
def get_state(hass, utc_point_in_time, entity_id, run=None): def get_state(hass, utc_point_in_time, entity_id, run=None, no_attributes=False):
"""Return a state at a specific point in time.""" """Return a state at a specific point in time."""
states = get_states(hass, utc_point_in_time, (entity_id,), run) states = get_states(hass, utc_point_in_time, (entity_id,), run, None, no_attributes)
return states[0] if states else None return states[0] if states else None

View File

@ -17,8 +17,12 @@ from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM
from tests.common import init_recorder_component from tests.common import async_init_recorder_component, init_recorder_component
from tests.components.recorder.common import trigger_db_commit, wait_recording_done from tests.components.recorder.common import (
async_wait_recording_done_without_instance,
trigger_db_commit,
wait_recording_done,
)
@pytest.mark.usefixtures("hass_history") @pytest.mark.usefixtures("hass_history")
@ -604,14 +608,36 @@ async def test_fetch_period_api_with_use_include_order(hass, hass_client):
async def test_fetch_period_api_with_minimal_response(hass, hass_client): async def test_fetch_period_api_with_minimal_response(hass, hass_client):
"""Test the fetch period view for history with minimal_response.""" """Test the fetch period view for history with minimal_response."""
await hass.async_add_executor_job(init_recorder_component, hass) await async_init_recorder_component(hass)
now = dt_util.utcnow()
await async_setup_component(hass, "history", {}) await async_setup_component(hass, "history", {})
await hass.async_add_executor_job(hass.data[recorder.DATA_INSTANCE].block_till_done)
hass.states.async_set("sensor.power", 0, {"attr": "any"})
await async_wait_recording_done_without_instance(hass)
hass.states.async_set("sensor.power", 50, {"attr": "any"})
await async_wait_recording_done_without_instance(hass)
hass.states.async_set("sensor.power", 23, {"attr": "any"})
await async_wait_recording_done_without_instance(hass)
client = await hass_client() client = await hass_client()
response = await client.get( response = await client.get(
f"/api/history/period/{dt_util.utcnow().isoformat()}?minimal_response" f"/api/history/period/{now.isoformat()}?filter_entity_id=sensor.power&minimal_response&no_attributes"
) )
assert response.status == HTTPStatus.OK assert response.status == HTTPStatus.OK
response_json = await response.json()
assert len(response_json[0]) == 3
state_list = response_json[0]
assert state_list[0]["entity_id"] == "sensor.power"
assert state_list[0]["attributes"] == {}
assert state_list[0]["state"] == "0"
assert "attributes" not in state_list[1]
assert "entity_id" not in state_list[1]
assert state_list[1]["state"] == "50"
assert state_list[2]["entity_id"] == "sensor.power"
assert state_list[2]["attributes"] == {}
assert state_list[2]["state"] == "23"
async def test_fetch_period_api_with_no_timestamp(hass, hass_client): async def test_fetch_period_api_with_no_timestamp(hass, hass_client):

View File

@ -5,6 +5,8 @@ from datetime import timedelta
import json import json
from unittest.mock import patch, sentinel from unittest.mock import patch, sentinel
import pytest
from homeassistant.components.recorder import history from homeassistant.components.recorder import history
from homeassistant.components.recorder.models import process_timestamp from homeassistant.components.recorder.models import process_timestamp
import homeassistant.core as ha import homeassistant.core as ha
@ -15,11 +17,9 @@ from tests.common import mock_state_change_event
from tests.components.recorder.common import wait_recording_done from tests.components.recorder.common import wait_recording_done
def test_get_states(hass_recorder): def _setup_get_states(hass):
"""Test getting states at a specific point in time.""" """Set up for testing get_states."""
hass = hass_recorder()
states = [] states = []
now = dt_util.utcnow() now = dt_util.utcnow()
with patch("homeassistant.components.recorder.dt_util.utcnow", return_value=now): with patch("homeassistant.components.recorder.dt_util.utcnow", return_value=now):
for i in range(5): for i in range(5):
@ -48,6 +48,13 @@ def test_get_states(hass_recorder):
wait_recording_done(hass) wait_recording_done(hass)
return now, future, states
def test_get_states(hass_recorder):
"""Test getting states at a specific point in time."""
hass = hass_recorder()
now, future, states = _setup_get_states(hass)
# Get states returns everything before POINT for all entities # Get states returns everything before POINT for all entities
for state1, state2 in zip( for state1, state2 in zip(
states, states,
@ -75,14 +82,65 @@ def test_get_states(hass_recorder):
assert history.get_state(hass, time_before_recorder_ran, "demo.id") is None assert history.get_state(hass, time_before_recorder_ran, "demo.id") is None
def test_state_changes_during_period(hass_recorder): def test_get_states_no_attributes(hass_recorder):
"""Test getting states without attributes at a specific point in time."""
hass = hass_recorder()
now, future, states = _setup_get_states(hass)
for state in states:
state.attributes = {}
# Get states returns everything before POINT for all entities
for state1, state2 in zip(
states,
sorted(
history.get_states(hass, future, no_attributes=True),
key=lambda state: state.entity_id,
),
):
assert state1 == state2
# Get states returns everything before POINT for tested entities
entities = [f"test.point_in_time_{i % 5}" for i in range(5)]
for state1, state2 in zip(
states,
sorted(
history.get_states(hass, future, entities, no_attributes=True),
key=lambda state: state.entity_id,
),
):
assert state1 == state2
# Test get_state here because we have a DB setup
assert states[0] == history.get_state(
hass, future, states[0].entity_id, no_attributes=True
)
time_before_recorder_ran = now - timedelta(days=1000)
assert history.get_states(hass, time_before_recorder_ran, no_attributes=True) == []
assert (
history.get_state(hass, time_before_recorder_ran, "demo.id", no_attributes=True)
is None
)
@pytest.mark.parametrize(
"attributes, no_attributes, limit",
[
({"attr": True}, False, 5000),
({}, True, 5000),
({"attr": True}, False, 3),
({}, True, 3),
],
)
def test_state_changes_during_period(hass_recorder, attributes, no_attributes, limit):
"""Test state change during period.""" """Test state change during period."""
hass = hass_recorder() hass = hass_recorder()
entity_id = "media_player.test" entity_id = "media_player.test"
def set_state(state): def set_state(state):
"""Set the state.""" """Set the state."""
hass.states.set(entity_id, state) hass.states.set(entity_id, state, attributes)
wait_recording_done(hass) wait_recording_done(hass)
return hass.states.get(entity_id) return hass.states.get(entity_id)
@ -106,9 +164,11 @@ def test_state_changes_during_period(hass_recorder):
set_state("Netflix") set_state("Netflix")
set_state("Plex") set_state("Plex")
hist = history.state_changes_during_period(hass, start, end, entity_id) hist = history.state_changes_during_period(
hass, start, end, entity_id, no_attributes, limit=limit
)
assert states == hist[entity_id] assert states[:limit] == hist[entity_id]
def test_get_last_state_changes(hass_recorder): def test_get_last_state_changes(hass_recorder):