diff --git a/homeassistant/components/logbook/__init__.py b/homeassistant/components/logbook/__init__.py index ee2ae3da4d9..0c614972e1e 100644 --- a/homeassistant/components/logbook/__init__.py +++ b/homeassistant/components/logbook/__init__.py @@ -127,7 +127,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: possible_merged_entities_filter = convert_include_exclude_filter(merged_filter) if not possible_merged_entities_filter.empty_filter: filters = sqlalchemy_filter_from_include_exclude_conf(merged_filter) - entities_filter = possible_merged_entities_filter + entities_filter = possible_merged_entities_filter.get_filter() else: filters = None entities_filter = None diff --git a/homeassistant/components/logbook/helpers.py b/homeassistant/components/logbook/helpers.py index c8f55331de1..3a1ec971b54 100644 --- a/homeassistant/components/logbook/helpers.py +++ b/homeassistant/components/logbook/helpers.py @@ -23,7 +23,6 @@ from homeassistant.core import ( split_entity_id, ) from homeassistant.helpers import device_registry as dr, entity_registry as er -from homeassistant.helpers.entityfilter import EntityFilter from homeassistant.helpers.event import async_track_state_change_event from .const import ALWAYS_CONTINUOUS_DOMAINS, AUTOMATION_EVENTS, BUILT_IN_EVENTS, DOMAIN @@ -104,7 +103,7 @@ def extract_attr(source: dict[str, Any], attr: str) -> list[str]: @callback def event_forwarder_filtered( target: Callable[[Event], None], - entities_filter: EntityFilter | None, + entities_filter: Callable[[str], bool] | None, entity_ids: list[str] | None, device_ids: list[str] | None, ) -> Callable[[Event], None]: @@ -159,7 +158,7 @@ def async_subscribe_events( subscriptions: list[CALLBACK_TYPE], target: Callable[[Event], None], event_types: tuple[str, ...], - entities_filter: EntityFilter | None, + entities_filter: Callable[[str], bool] | None, entity_ids: list[str] | None, device_ids: list[str] | None, ) -> None: diff --git a/homeassistant/components/logbook/models.py b/homeassistant/components/logbook/models.py index 86dcfdf82c5..e351ee6bb61 100644 --- a/homeassistant/components/logbook/models.py +++ b/homeassistant/components/logbook/models.py @@ -16,7 +16,6 @@ from homeassistant.components.recorder.models import ( ) from homeassistant.const import ATTR_ICON, EVENT_STATE_CHANGED from homeassistant.core import Context, Event, State, callback -from homeassistant.helpers.entityfilter import EntityFilter import homeassistant.util.dt as dt_util from homeassistant.util.json import json_loads from homeassistant.util.ulid import ulid_to_bytes @@ -30,7 +29,7 @@ class LogbookConfig: str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]] ] sqlalchemy_filter: Filters | None = None - entity_filter: EntityFilter | None = None + entity_filter: Callable[[str], bool] | None = None class LazyEventPartialState: diff --git a/homeassistant/components/logbook/rest_api.py b/homeassistant/components/logbook/rest_api.py index a1a7db3ed2c..57d0a6695c7 100644 --- a/homeassistant/components/logbook/rest_api.py +++ b/homeassistant/components/logbook/rest_api.py @@ -1,6 +1,7 @@ """Event parser and human readable log generator.""" from __future__ import annotations +from collections.abc import Callable from datetime import timedelta from http import HTTPStatus from typing import Any, cast @@ -14,7 +15,6 @@ from homeassistant.components.recorder.filters import Filters from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import InvalidEntityFormatError from homeassistant.helpers import config_validation as cv -from homeassistant.helpers.entityfilter import EntityFilter from homeassistant.helpers.typing import ConfigType import homeassistant.util.dt as dt_util @@ -27,7 +27,7 @@ def async_setup( hass: HomeAssistant, conf: ConfigType, filters: Filters | None, - entities_filter: EntityFilter | None, + entities_filter: Callable[[str], bool] | None, ) -> None: """Set up the logbook rest API.""" hass.http.register_view(LogbookView(conf, filters, entities_filter)) @@ -44,7 +44,7 @@ class LogbookView(HomeAssistantView): self, config: dict[str, Any], filters: Filters | None, - entities_filter: EntityFilter | None, + entities_filter: Callable[[str], bool] | None, ) -> None: """Initialize the logbook view.""" self.config = config diff --git a/homeassistant/components/logbook/websocket_api.py b/homeassistant/components/logbook/websocket_api.py index c4e6b9814f4..4afa40cb14f 100644 --- a/homeassistant/components/logbook/websocket_api.py +++ b/homeassistant/components/logbook/websocket_api.py @@ -15,7 +15,6 @@ from homeassistant.components.recorder import get_instance from homeassistant.components.websocket_api import messages from homeassistant.components.websocket_api.connection import ActiveConnection from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback -from homeassistant.helpers.entityfilter import EntityFilter from homeassistant.helpers.event import async_track_point_in_utc_time from homeassistant.helpers.json import JSON_DUMP import homeassistant.util.dt as dt_util @@ -357,7 +356,7 @@ async def ws_event_stream( ) _unsub() - entities_filter: EntityFilter | None = None + entities_filter: Callable[[str], bool] | None = None if not event_processor.limited_select: logbook_config: LogbookConfig = hass.data[DOMAIN] entities_filter = logbook_config.entity_filter diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 7b43abd8dde..72d825d9e78 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -135,7 +135,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: exclude_attributes_by_domain: dict[str, set[str]] = {} hass.data[EXCLUDE_ATTRIBUTES] = exclude_attributes_by_domain conf = config[DOMAIN] - entity_filter = convert_include_exclude_filter(conf) + entity_filter = convert_include_exclude_filter(conf).get_filter() auto_purge = conf[CONF_AUTO_PURGE] auto_repack = conf[CONF_AUTO_REPACK] keep_days = conf[CONF_PURGE_KEEP_DAYS] diff --git a/homeassistant/helpers/entityfilter.py b/homeassistant/helpers/entityfilter.py index 057e8f0955e..a9d3ccad138 100644 --- a/homeassistant/helpers/entityfilter.py +++ b/homeassistant/helpers/entityfilter.py @@ -35,7 +35,14 @@ class EntityFilter: self._exclude_d = set(config[CONF_EXCLUDE_DOMAINS]) self._include_eg = _convert_globs_to_pattern(config[CONF_INCLUDE_ENTITY_GLOBS]) self._exclude_eg = _convert_globs_to_pattern(config[CONF_EXCLUDE_ENTITY_GLOBS]) - self._filter: Callable[[str], bool] | None = None + self._filter = _generate_filter_from_sets_and_pattern_lists( + self._include_d, + self._include_e, + self._exclude_d, + self._exclude_e, + self._include_eg, + self._exclude_eg, + ) def explicitly_included(self, entity_id: str) -> bool: """Check if an entity is explicitly included.""" @@ -49,17 +56,12 @@ class EntityFilter: bool(self._exclude_eg and self._exclude_eg.match(entity_id)) ) + def get_filter(self) -> Callable[[str], bool]: + """Return the filter function.""" + return self._filter + def __call__(self, entity_id: str) -> bool: """Run the filter.""" - if self._filter is None: - self._filter = _generate_filter_from_sets_and_pattern_lists( - self._include_d, - self._include_e, - self._exclude_d, - self._exclude_e, - self._include_eg, - self._exclude_eg, - ) return self._filter(entity_id) diff --git a/tests/helpers/test_entityfilter.py b/tests/helpers/test_entityfilter.py index 2141c286914..48bc8110ec5 100644 --- a/tests/helpers/test_entityfilter.py +++ b/tests/helpers/test_entityfilter.py @@ -395,6 +395,28 @@ def test_explicitly_included() -> None: assert filt.explicitly_excluded("light.kitchen") +def test_get_filter() -> None: + """Test we can get the underlying filter.""" + conf = { + "include": { + "domains": ["light"], + "entity_globs": ["sensor.kitchen_*"], + "entities": ["switch.kitchen"], + }, + "exclude": { + "domains": ["cover"], + "entity_globs": ["sensor.weather_*"], + "entities": ["light.kitchen"], + }, + } + filt: EntityFilter = INCLUDE_EXCLUDE_FILTER_SCHEMA(conf) + underlying_filter = filt.get_filter() + assert underlying_filter("light.any") + assert not underlying_filter("switch.other") + assert underlying_filter("sensor.kitchen_4") + assert underlying_filter("switch.kitchen") + + def test_complex_include_exclude_filter() -> None: """Test a complex include exclude filter.""" conf = {