diff --git a/homeassistant/components/influxdb/__init__.py b/homeassistant/components/influxdb/__init__.py index 9823d57e200..94a68c25504 100644 --- a/homeassistant/components/influxdb/__init__.py +++ b/homeassistant/components/influxdb/__init__.py @@ -12,15 +12,12 @@ from influxdb_client import InfluxDBClient as InfluxDBClientV2 from influxdb_client.client.write_api import ASYNCHRONOUS, SYNCHRONOUS from influxdb_client.rest import ApiException import requests.exceptions +import urllib3.exceptions import voluptuous as vol from homeassistant.const import ( CONF_API_VERSION, - CONF_DOMAINS, - CONF_ENTITIES, - CONF_EXCLUDE, CONF_HOST, - CONF_INCLUDE, CONF_PASSWORD, CONF_PATH, CONF_PORT, @@ -37,6 +34,10 @@ from homeassistant.const import ( from homeassistant.helpers import event as event_helper, state as state_helper import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_values import EntityValues +from homeassistant.helpers.entityfilter import ( + INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA, + convert_include_exclude_filter, +) _LOGGER = logging.getLogger(__name__) @@ -141,24 +142,8 @@ COMPONENT_CONFIG_SCHEMA_CONNECTION = { _CONFIG_SCHEMA_ENTRY = vol.Schema({vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string}) -_CONFIG_SCHEMA = vol.Schema( +_CONFIG_SCHEMA = INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.extend( { - vol.Optional(CONF_EXCLUDE, default={}): vol.Schema( - { - vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, - vol.Optional(CONF_DOMAINS, default=[]): vol.All( - cv.ensure_list, [cv.string] - ), - } - ), - vol.Optional(CONF_INCLUDE, default={}): vol.Schema( - { - vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, - vol.Optional(CONF_DOMAINS, default=[]): vol.All( - cv.ensure_list, [cv.string] - ), - } - ), vol.Optional(CONF_RETRY_COUNT, default=0): cv.positive_int, vol.Optional(CONF_DEFAULT_MEASUREMENT): cv.string, vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string, @@ -253,12 +238,7 @@ def setup(hass, config): if CONF_SSL in conf: kwargs["ssl"] = conf[CONF_SSL] - include = conf.get(CONF_INCLUDE, {}) - exclude = conf.get(CONF_EXCLUDE, {}) - whitelist_e = set(include.get(CONF_ENTITIES, [])) - whitelist_d = set(include.get(CONF_DOMAINS, [])) - blacklist_e = set(exclude.get(CONF_ENTITIES, [])) - blacklist_d = set(exclude.get(CONF_DOMAINS, [])) + entity_filter = convert_include_exclude_filter(conf) tags = conf.get(CONF_TAGS) tags_attributes = conf.get(CONF_TAGS_ATTRIBUTES) default_measurement = conf.get(CONF_DEFAULT_MEASUREMENT) @@ -285,7 +265,7 @@ def setup(hass, config): ) event_helper.call_later(hass, RETRY_INTERVAL, lambda _: setup(hass, config)) return True - except ApiException as exc: + except (ApiException, urllib3.exceptions.HTTPError) as exc: _LOGGER.error( "Bucket is not accessible due to '%s', please " "check your entries in the configuration file (url, org, " @@ -303,19 +283,11 @@ def setup(hass, config): if ( state is None or state.state in (STATE_UNKNOWN, "", STATE_UNAVAILABLE) - or state.entity_id in blacklist_e - or state.domain in blacklist_d + or not entity_filter(state.entity_id) ): return try: - if ( - (whitelist_e or whitelist_d) - and state.entity_id not in whitelist_e - and state.domain not in whitelist_d - ): - return - _include_state = _include_value = False _state_as_value = float(state.state) diff --git a/tests/components/influxdb/test_init.py b/tests/components/influxdb/test_init.py index f9514f7ebff..29247bec9c8 100644 --- a/tests/components/influxdb/test_init.py +++ b/tests/components/influxdb/test_init.py @@ -1,4 +1,5 @@ """The tests for the InfluxDB component.""" +from dataclasses import dataclass import datetime import pytest @@ -11,6 +12,7 @@ from homeassistant.const import ( STATE_STANDBY, UNIT_PERCENTAGE, ) +from homeassistant.core import split_entity_id from homeassistant.setup import async_setup_component from tests.async_mock import MagicMock, Mock, call, patch @@ -23,6 +25,14 @@ BASE_V2_CONFIG = { } +@dataclass +class FilterTest: + """Class for capturing a filter test.""" + + id: str + should_pass: bool + + @pytest.fixture(autouse=True) def mock_batch_timeout(hass, monkeypatch): """Mock the event bus listener and the batch timeout for tests.""" @@ -421,43 +431,22 @@ async def test_event_listener_states( write_api.reset_mock() -@pytest.mark.parametrize( - "mock_client, config_ext, get_write_api, get_mock_call", - [ - ( - influxdb.DEFAULT_API_VERSION, - BASE_V1_CONFIG, - _get_write_api_mock_v1, - influxdb.DEFAULT_API_VERSION, - ), - ( - influxdb.API_VERSION_2, - BASE_V2_CONFIG, - _get_write_api_mock_v2, - influxdb.API_VERSION_2, - ), - ], - indirect=["mock_client", "get_mock_call"], -) -async def test_event_listener_blacklist( - hass, mock_client, config_ext, get_write_api, get_mock_call -): - """Test the event listener against a blacklist.""" - handler_method = await _setup(hass, mock_client, config_ext, get_write_api) - - for entity_id in ("ok", "blacklisted"): +def execute_filter_test(hass, tests, handler_method, write_api, get_mock_call): + """Execute all tests for a given filtering test.""" + for test in tests: + domain, entity_id = split_entity_id(test.id) state = MagicMock( state=1, - domain="fake", - entity_id=f"fake.{entity_id}", + domain=domain, + entity_id=test.id, object_id=entity_id, attributes={}, ) event = MagicMock(data={"new_state": state}, time_fired=12345) body = [ { - "measurement": f"fake.{entity_id}", - "tags": {"domain": "fake", "entity_id": entity_id}, + "measurement": test.id, + "tags": {"domain": domain, "entity_id": entity_id}, "time": 12345, "fields": {"value": 1}, } @@ -465,9 +454,8 @@ async def test_event_listener_blacklist( handler_method(event) hass.data[influxdb.DOMAIN].block_till_done() - write_api = get_write_api(mock_client) - if entity_id == "ok": - assert write_api.call_count == 1 + if test.should_pass: + write_api.assert_called_once() assert write_api.call_args == get_mock_call(body) else: assert not write_api.called @@ -492,94 +480,20 @@ async def test_event_listener_blacklist( ], indirect=["mock_client", "get_mock_call"], ) -async def test_event_listener_blacklist_domain( +async def test_event_listener_denylist( hass, mock_client, config_ext, get_write_api, get_mock_call ): - """Test the event listener against a domain blacklist.""" - handler_method = await _setup(hass, mock_client, config_ext, get_write_api) - - for domain in ("ok", "another_fake"): - state = MagicMock( - state=1, - domain=domain, - entity_id=f"{domain}.something", - object_id="something", - attributes={}, - ) - event = MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": f"{domain}.something", - "tags": {"domain": domain, "entity_id": "something"}, - "time": 12345, - "fields": {"value": 1}, - } - ] - handler_method(event) - hass.data[influxdb.DOMAIN].block_till_done() - - write_api = get_write_api(mock_client) - if domain == "ok": - assert write_api.call_count == 1 - assert write_api.call_args == get_mock_call(body) - else: - assert not write_api.called - write_api.reset_mock() - - -@pytest.mark.parametrize( - "mock_client, config_ext, get_write_api, get_mock_call", - [ - ( - influxdb.DEFAULT_API_VERSION, - BASE_V1_CONFIG, - _get_write_api_mock_v1, - influxdb.DEFAULT_API_VERSION, - ), - ( - influxdb.API_VERSION_2, - BASE_V2_CONFIG, - _get_write_api_mock_v2, - influxdb.API_VERSION_2, - ), - ], - indirect=["mock_client", "get_mock_call"], -) -async def test_event_listener_whitelist( - hass, mock_client, config_ext, get_write_api, get_mock_call -): - """Test the event listener against a whitelist.""" - config = {"include": {"entities": ["fake.included"]}} + """Test the event listener against a denylist.""" + config = {"exclude": {"entities": ["fake.denylisted"]}, "include": {}} config.update(config_ext) handler_method = await _setup(hass, mock_client, config, get_write_api) + write_api = get_write_api(mock_client) - for entity_id in ("included", "default"): - state = MagicMock( - state=1, - domain="fake", - entity_id=f"fake.{entity_id}", - object_id=entity_id, - attributes={}, - ) - event = MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": f"fake.{entity_id}", - "tags": {"domain": "fake", "entity_id": entity_id}, - "time": 12345, - "fields": {"value": 1}, - } - ] - handler_method(event) - hass.data[influxdb.DOMAIN].block_till_done() - - write_api = get_write_api(mock_client) - if entity_id == "included": - assert write_api.call_count == 1 - assert write_api.call_args == get_mock_call(body) - else: - assert not write_api.called - write_api.reset_mock() + tests = [ + FilterTest("fake.ok", True), + FilterTest("fake.denylisted", False), + ] + execute_filter_test(hass, tests, handler_method, write_api, get_mock_call) @pytest.mark.parametrize( @@ -600,41 +514,20 @@ async def test_event_listener_whitelist( ], indirect=["mock_client", "get_mock_call"], ) -async def test_event_listener_whitelist_domain( +async def test_event_listener_denylist_domain( hass, mock_client, config_ext, get_write_api, get_mock_call ): - """Test the event listener against a domain whitelist.""" - config = {"include": {"domains": ["fake"]}} + """Test the event listener against a domain denylist.""" + config = {"exclude": {"domains": ["another_fake"]}, "include": {}} config.update(config_ext) handler_method = await _setup(hass, mock_client, config, get_write_api) + write_api = get_write_api(mock_client) - for domain in ("fake", "another_fake"): - state = MagicMock( - state=1, - domain=domain, - entity_id=f"{domain}.something", - object_id="something", - attributes={}, - ) - event = MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": f"{domain}.something", - "tags": {"domain": domain, "entity_id": "something"}, - "time": 12345, - "fields": {"value": 1}, - } - ] - handler_method(event) - hass.data[influxdb.DOMAIN].block_till_done() - - write_api = get_write_api(mock_client) - if domain == "fake": - assert write_api.call_count == 1 - assert write_api.call_args == get_mock_call(body) - else: - assert not write_api.called - write_api.reset_mock() + tests = [ + FilterTest("fake.ok", True), + FilterTest("another_fake.denylisted", False), + ] + execute_filter_test(hass, tests, handler_method, write_api, get_mock_call) @pytest.mark.parametrize( @@ -655,69 +548,212 @@ async def test_event_listener_whitelist_domain( ], indirect=["mock_client", "get_mock_call"], ) -async def test_event_listener_whitelist_domain_and_entities( +async def test_event_listener_denylist_glob( hass, mock_client, config_ext, get_write_api, get_mock_call ): - """Test the event listener against a domain and entity whitelist.""" - config = {"include": {"domains": ["fake"], "entities": ["other.one"]}} + """Test the event listener against a glob denylist.""" + config = {"exclude": {"entity_globs": ["*.excluded_*"]}, "include": {}} config.update(config_ext) handler_method = await _setup(hass, mock_client, config, get_write_api) + write_api = get_write_api(mock_client) - for domain in ("fake", "another_fake"): - state = MagicMock( - state=1, - domain=domain, - entity_id=f"{domain}.something", - object_id="something", - attributes={}, - ) - event = MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": f"{domain}.something", - "tags": {"domain": domain, "entity_id": "something"}, - "time": 12345, - "fields": {"value": 1}, - } - ] - handler_method(event) - hass.data[influxdb.DOMAIN].block_till_done() + tests = [ + FilterTest("fake.ok", True), + FilterTest("fake.excluded_entity", False), + ] + execute_filter_test(hass, tests, handler_method, write_api, get_mock_call) - write_api = get_write_api(mock_client) - if domain == "fake": - assert write_api.call_count == 1 - assert write_api.call_args == get_mock_call(body) - else: - assert not write_api.called - write_api.reset_mock() - for entity_id in ("one", "two"): - state = MagicMock( - state=1, - domain="other", - entity_id=f"other.{entity_id}", - object_id=entity_id, - attributes={}, - ) - event = MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": f"other.{entity_id}", - "tags": {"domain": "other", "entity_id": entity_id}, - "time": 12345, - "fields": {"value": 1}, - } - ] - handler_method(event) - hass.data[influxdb.DOMAIN].block_till_done() +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_allowlist( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener against an allowlist.""" + config = {"include": {"entities": ["fake.included"]}, "exclude": {}} + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + write_api = get_write_api(mock_client) - write_api = get_write_api(mock_client) - if entity_id == "one": - assert write_api.call_count == 1 - assert write_api.call_args == get_mock_call(body) - else: - assert not write_api.called - write_api.reset_mock() + tests = [ + FilterTest("fake.included", True), + FilterTest("fake.excluded", False), + ] + execute_filter_test(hass, tests, handler_method, write_api, get_mock_call) + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_allowlist_domain( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener against a domain allowlist.""" + config = {"include": {"domains": ["fake"]}, "exclude": {}} + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + write_api = get_write_api(mock_client) + + tests = [ + FilterTest("fake.ok", True), + FilterTest("another_fake.excluded", False), + ] + execute_filter_test(hass, tests, handler_method, write_api, get_mock_call) + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_allowlist_glob( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener against a glob allowlist.""" + config = {"include": {"entity_globs": ["*.included_*"]}, "exclude": {}} + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + write_api = get_write_api(mock_client) + + tests = [ + FilterTest("fake.included_entity", True), + FilterTest("fake.denied", False), + ] + execute_filter_test(hass, tests, handler_method, write_api, get_mock_call) + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_filtered_allowlist( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener against an allowlist filtered by denylist.""" + config = { + "include": { + "domains": ["fake"], + "entities": ["another_fake.included"], + "entity_globs": "*.included_*", + }, + "exclude": { + "entities": ["fake.excluded"], + "domains": ["another_fake"], + "entity_globs": "*.excluded_*", + }, + } + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + write_api = get_write_api(mock_client) + + tests = [ + FilterTest("fake.ok", True), + FilterTest("another_fake.included", True), + FilterTest("test.included_entity", True), + FilterTest("fake.excluded", False), + FilterTest("another_fake.denied", False), + FilterTest("fake.excluded_entity", False), + FilterTest("another_fake.included_entity", False), + ] + execute_filter_test(hass, tests, handler_method, write_api, get_mock_call) + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_filtered_denylist( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener against a domain/glob denylist with an entity id allowlist.""" + config = { + "include": {"entities": ["another_fake.included", "fake.excluded_pass"]}, + "exclude": {"domains": ["another_fake"], "entity_globs": "*.excluded_*"}, + } + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + write_api = get_write_api(mock_client) + + tests = [ + FilterTest("fake.ok", True), + FilterTest("another_fake.included", True), + FilterTest("fake.excluded_pass", True), + FilterTest("another_fake.denied", False), + FilterTest("fake.excluded_entity", False), + ] + execute_filter_test(hass, tests, handler_method, write_api, get_mock_call) @pytest.mark.parametrize(