diff --git a/homeassistant/components/influxdb/__init__.py b/homeassistant/components/influxdb/__init__.py index 0d1999e0d7b..9823d57e200 100644 --- a/homeassistant/components/influxdb/__init__.py +++ b/homeassistant/components/influxdb/__init__.py @@ -5,12 +5,17 @@ import queue import re import threading import time +from typing import Dict from influxdb import InfluxDBClient, exceptions +from influxdb_client import InfluxDBClient as InfluxDBClientV2 +from influxdb_client.client.write_api import ASYNCHRONOUS, SYNCHRONOUS +from influxdb_client.rest import ApiException import requests.exceptions import voluptuous as vol from homeassistant.const import ( + CONF_API_VERSION, CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, @@ -20,6 +25,8 @@ from homeassistant.const import ( CONF_PATH, CONF_PORT, CONF_SSL, + CONF_TOKEN, + CONF_URL, CONF_USERNAME, CONF_VERIFY_SSL, EVENT_HOMEASSISTANT_STOP, @@ -34,6 +41,8 @@ from homeassistant.helpers.entity_values import EntityValues _LOGGER = logging.getLogger(__name__) CONF_DB_NAME = "database" +CONF_BUCKET = "bucket" +CONF_ORG = "organization" CONF_TAGS = "tags" CONF_DEFAULT_MEASUREMENT = "default_measurement" CONF_OVERRIDE_MEASUREMENT = "override_measurement" @@ -44,9 +53,14 @@ CONF_COMPONENT_CONFIG_DOMAIN = "component_config_domain" CONF_RETRY_COUNT = "max_retries" DEFAULT_DATABASE = "home_assistant" +DEFAULT_HOST_V2 = "us-west-2-1.aws.cloud2.influxdata.com" +DEFAULT_SSL_V2 = True +DEFAULT_BUCKET = "Home Assistant" DEFAULT_VERIFY_SSL = True -DOMAIN = "influxdb" +DEFAULT_API_VERSION = "1" +DOMAIN = "influxdb" +API_VERSION_2 = "2" TIMEOUT = 5 RETRY_DELAY = 20 QUEUE_BACKLOG_SECONDS = 30 @@ -55,62 +69,122 @@ RETRY_INTERVAL = 60 # seconds BATCH_TIMEOUT = 1 BATCH_BUFFER_SIZE = 100 -COMPONENT_CONFIG_SCHEMA_ENTRY = vol.Schema( - {vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string} +DB_CONNECTION_FAILURE_MSG = () + + +def create_influx_url(conf: Dict) -> Dict: + """Build URL used from config inputs and default when necessary.""" + if conf[CONF_API_VERSION] == API_VERSION_2: + if CONF_SSL not in conf: + conf[CONF_SSL] = DEFAULT_SSL_V2 + if CONF_HOST not in conf: + conf[CONF_HOST] = DEFAULT_HOST_V2 + + url = conf[CONF_HOST] + if conf[CONF_SSL]: + url = f"https://{url}" + else: + url = f"http://{url}" + + if CONF_PORT in conf: + url = f"{url}:{conf[CONF_PORT]}" + + if CONF_PATH in conf: + url = f"{url}{conf[CONF_PATH]}" + + conf[CONF_URL] = url + + return conf + + +def validate_version_specific_config(conf: Dict) -> Dict: + """Ensure correct config fields are provided based on API version used.""" + if conf[CONF_API_VERSION] == API_VERSION_2: + if CONF_TOKEN not in conf: + raise vol.Invalid( + f"{CONF_TOKEN} and {CONF_BUCKET} are required when {CONF_API_VERSION} is {API_VERSION_2}" + ) + + if CONF_USERNAME in conf: + raise vol.Invalid( + f"{CONF_USERNAME} and {CONF_PASSWORD} are only allowed when {CONF_API_VERSION} is {DEFAULT_API_VERSION}" + ) + + else: + if CONF_TOKEN in conf: + raise vol.Invalid( + f"{CONF_TOKEN} and {CONF_BUCKET} are only allowed when {CONF_API_VERSION} is {API_VERSION_2}" + ) + + return conf + + +COMPONENT_CONFIG_SCHEMA_CONNECTION = { + # Connection config for V1 and V2 APIs. + vol.Optional(CONF_API_VERSION, default=DEFAULT_API_VERSION): vol.All( + vol.Coerce(str), vol.In([DEFAULT_API_VERSION, API_VERSION_2]), + ), + vol.Optional(CONF_HOST): cv.string, + vol.Optional(CONF_PATH): cv.string, + vol.Optional(CONF_PORT): cv.port, + vol.Optional(CONF_SSL): cv.boolean, + # Connection config for V1 API only. + vol.Inclusive(CONF_USERNAME, "authentication"): cv.string, + vol.Inclusive(CONF_PASSWORD, "authentication"): cv.string, + vol.Optional(CONF_DB_NAME, default=DEFAULT_DATABASE): cv.string, + vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean, + # Connection config for V2 API only. + vol.Inclusive(CONF_TOKEN, "v2_authentication"): cv.string, + vol.Inclusive(CONF_ORG, "v2_authentication"): cv.string, + vol.Optional(CONF_BUCKET, default=DEFAULT_BUCKET): cv.string, +} + +_CONFIG_SCHEMA_ENTRY = vol.Schema({vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string}) + +_CONFIG_SCHEMA = vol.Schema( + { + vol.Optional(CONF_EXCLUDE, default={}): vol.Schema( + { + vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, + vol.Optional(CONF_DOMAINS, default=[]): vol.All( + cv.ensure_list, [cv.string] + ), + } + ), + vol.Optional(CONF_INCLUDE, default={}): vol.Schema( + { + vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, + vol.Optional(CONF_DOMAINS, default=[]): vol.All( + cv.ensure_list, [cv.string] + ), + } + ), + vol.Optional(CONF_RETRY_COUNT, default=0): cv.positive_int, + vol.Optional(CONF_DEFAULT_MEASUREMENT): cv.string, + vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string, + vol.Optional(CONF_TAGS, default={}): vol.Schema({cv.string: cv.string}), + vol.Optional(CONF_TAGS_ATTRIBUTES, default=[]): vol.All( + cv.ensure_list, [cv.string] + ), + vol.Optional(CONF_COMPONENT_CONFIG, default={}): vol.Schema( + {cv.entity_id: _CONFIG_SCHEMA_ENTRY} + ), + vol.Optional(CONF_COMPONENT_CONFIG_GLOB, default={}): vol.Schema( + {cv.string: _CONFIG_SCHEMA_ENTRY} + ), + vol.Optional(CONF_COMPONENT_CONFIG_DOMAIN, default={}): vol.Schema( + {cv.string: _CONFIG_SCHEMA_ENTRY} + ), + } ) CONFIG_SCHEMA = vol.Schema( { DOMAIN: vol.All( - vol.Schema( - { - vol.Optional(CONF_HOST): cv.string, - vol.Inclusive(CONF_USERNAME, "authentication"): cv.string, - vol.Inclusive(CONF_PASSWORD, "authentication"): cv.string, - vol.Optional(CONF_EXCLUDE, default={}): vol.Schema( - { - vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, - vol.Optional(CONF_DOMAINS, default=[]): vol.All( - cv.ensure_list, [cv.string] - ), - } - ), - vol.Optional(CONF_INCLUDE, default={}): vol.Schema( - { - vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, - vol.Optional(CONF_DOMAINS, default=[]): vol.All( - cv.ensure_list, [cv.string] - ), - } - ), - vol.Optional(CONF_DB_NAME, default=DEFAULT_DATABASE): cv.string, - vol.Optional(CONF_PATH): cv.string, - vol.Optional(CONF_PORT): cv.port, - vol.Optional(CONF_SSL): cv.boolean, - vol.Optional(CONF_RETRY_COUNT, default=0): cv.positive_int, - vol.Optional(CONF_DEFAULT_MEASUREMENT): cv.string, - vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string, - vol.Optional(CONF_TAGS, default={}): vol.Schema( - {cv.string: cv.string} - ), - vol.Optional(CONF_TAGS_ATTRIBUTES, default=[]): vol.All( - cv.ensure_list, [cv.string] - ), - vol.Optional( - CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL - ): cv.boolean, - vol.Optional(CONF_COMPONENT_CONFIG, default={}): vol.Schema( - {cv.entity_id: COMPONENT_CONFIG_SCHEMA_ENTRY} - ), - vol.Optional(CONF_COMPONENT_CONFIG_GLOB, default={}): vol.Schema( - {cv.string: COMPONENT_CONFIG_SCHEMA_ENTRY} - ), - vol.Optional(CONF_COMPONENT_CONFIG_DOMAIN, default={}): vol.Schema( - {cv.string: COMPONENT_CONFIG_SCHEMA_ENTRY} - ), - } - ) - ) + _CONFIG_SCHEMA.extend(COMPONENT_CONFIG_SCHEMA_CONNECTION), + validate_version_specific_config, + create_influx_url, + ), }, extra=vol.ALLOW_EXTRA, ) @@ -119,34 +193,65 @@ RE_DIGIT_TAIL = re.compile(r"^[^\.]*\d+\.?\d+[^\.]*$") RE_DECIMAL = re.compile(r"[^\d.]+") +def get_influx_connection(client_kwargs, bucket): + """Create and check the correct influx connection for the API version.""" + if bucket is not None: + # Test connection by synchronously writing nothing. + # If config is valid this will generate a `Bad Request` exception but not make anything. + # If config is invalid we will output an error. + # Hopefully a better way to test connection is added in the future. + try: + influx = InfluxDBClientV2(**client_kwargs) + influx.write_api(write_options=SYNCHRONOUS).write(bucket=bucket) + + except ApiException as exc: + # 400 is the success state since it means we can write we just gave a bad point. + if exc.status != 400: + raise exc + + else: + influx = InfluxDBClient(**client_kwargs) + influx.write_points([]) + + return influx + + def setup(hass, config): """Set up the InfluxDB component.""" - conf = config[DOMAIN] - + use_v2_api = conf[CONF_API_VERSION] == API_VERSION_2 + bucket = None kwargs = { - "database": conf[CONF_DB_NAME], - "verify_ssl": conf[CONF_VERIFY_SSL], "timeout": TIMEOUT, } - if CONF_HOST in conf: - kwargs["host"] = conf[CONF_HOST] + if use_v2_api: + kwargs["url"] = conf[CONF_URL] + kwargs["token"] = conf[CONF_TOKEN] + kwargs["org"] = conf[CONF_ORG] + bucket = conf[CONF_BUCKET] - if CONF_PATH in conf: - kwargs["path"] = conf[CONF_PATH] + else: + kwargs["database"] = conf[CONF_DB_NAME] + kwargs["verify_ssl"] = conf[CONF_VERIFY_SSL] - if CONF_PORT in conf: - kwargs["port"] = conf[CONF_PORT] + if CONF_USERNAME in conf: + kwargs["username"] = conf[CONF_USERNAME] - if CONF_USERNAME in conf: - kwargs["username"] = conf[CONF_USERNAME] + if CONF_PASSWORD in conf: + kwargs["password"] = conf[CONF_PASSWORD] - if CONF_PASSWORD in conf: - kwargs["password"] = conf[CONF_PASSWORD] + if CONF_HOST in conf: + kwargs["host"] = conf[CONF_HOST] - if CONF_SSL in conf: - kwargs["ssl"] = conf[CONF_SSL] + if CONF_PATH in conf: + kwargs["path"] = conf[CONF_PATH] + + if CONF_PORT in conf: + kwargs["port"] = conf[CONF_PORT] + + if CONF_SSL in conf: + kwargs["ssl"] = conf[CONF_SSL] include = conf.get(CONF_INCLUDE, {}) exclude = conf.get(CONF_EXCLUDE, {}) @@ -166,10 +271,11 @@ def setup(hass, config): max_tries = conf.get(CONF_RETRY_COUNT) try: - influx = InfluxDBClient(**kwargs) - influx.write_points([]) + influx = get_influx_connection(kwargs, bucket) + if use_v2_api: + write_api = influx.write_api(write_options=ASYNCHRONOUS) except (exceptions.InfluxDBClientError, requests.exceptions.ConnectionError) as exc: - _LOGGER.warning( + _LOGGER.error( "Database host is not accessible due to '%s', please " "check your entries in the configuration file (host, " "port, etc.) and verify that the database exists and is " @@ -179,6 +285,17 @@ def setup(hass, config): ) event_helper.call_later(hass, RETRY_INTERVAL, lambda _: setup(hass, config)) return True + except ApiException as exc: + _LOGGER.error( + "Bucket is not accessible due to '%s', please " + "check your entries in the configuration file (url, org, " + "bucket, etc.) and verify that the org and bucket exist and the " + "provided token has WRITE access. Retrying again in %s seconds.", + exc, + RETRY_INTERVAL, + ) + event_helper.call_later(hass, RETRY_INTERVAL, lambda _: setup(hass, config)) + return True def event_to_json(event): """Add an event to the outgoing Influx list.""" @@ -270,7 +387,15 @@ def setup(hass, config): return json - instance = hass.data[DOMAIN] = InfluxThread(hass, influx, event_to_json, max_tries) + if use_v2_api: + instance = hass.data[DOMAIN] = InfluxThread( + hass, None, bucket, write_api, event_to_json, max_tries + ) + else: + instance = hass.data[DOMAIN] = InfluxThread( + hass, influx, None, None, event_to_json, max_tries + ) + instance.start() def shutdown(event): @@ -287,11 +412,13 @@ def setup(hass, config): class InfluxThread(threading.Thread): """A threaded event handler class.""" - def __init__(self, hass, influx, event_to_json, max_tries): + def __init__(self, hass, influx, bucket, write_api, event_to_json, max_tries): """Initialize the listener.""" threading.Thread.__init__(self, name="InfluxDB") self.queue = queue.Queue() self.influx = influx + self.bucket = bucket + self.write_api = write_api self.event_to_json = event_to_json self.max_tries = max_tries self.write_errors = 0 @@ -346,10 +473,12 @@ class InfluxThread(threading.Thread): def write_to_influxdb(self, json): """Write preprocessed events to influxdb, with retry.""" - for retry in range(self.max_tries + 1): try: - self.influx.write_points(json) + if self.write_api is not None: + self.write_api.write(bucket=self.bucket, record=json) + else: + self.influx.write_points(json) if self.write_errors: _LOGGER.error("Resumed, lost %d events", self.write_errors) @@ -361,6 +490,7 @@ class InfluxThread(threading.Thread): exceptions.InfluxDBClientError, exceptions.InfluxDBServerError, OSError, + ApiException, ) as err: if retry < self.max_tries: time.sleep(RETRY_DELAY) diff --git a/homeassistant/components/influxdb/manifest.json b/homeassistant/components/influxdb/manifest.json index 94577f5735f..596c0ecc6ce 100644 --- a/homeassistant/components/influxdb/manifest.json +++ b/homeassistant/components/influxdb/manifest.json @@ -2,6 +2,6 @@ "domain": "influxdb", "name": "InfluxDB", "documentation": "https://www.home-assistant.io/integrations/influxdb", - "requirements": ["influxdb==5.2.3"], + "requirements": ["influxdb==5.2.3", "influxdb-client==1.6.0"], "codeowners": ["@fabaff"] } diff --git a/homeassistant/components/influxdb/sensor.py b/homeassistant/components/influxdb/sensor.py index 64ab1174b8b..0cf25c0b2f4 100644 --- a/homeassistant/components/influxdb/sensor.py +++ b/homeassistant/components/influxdb/sensor.py @@ -1,18 +1,25 @@ """InfluxDB component which allows you to get data from an Influx database.""" from datetime import timedelta import logging +from typing import Dict from influxdb import InfluxDBClient, exceptions +from influxdb_client import InfluxDBClient as InfluxDBClientV2 +from influxdb_client.rest import ApiException import voluptuous as vol from homeassistant.components.sensor import PLATFORM_SCHEMA from homeassistant.const import ( + CONF_API_VERSION, CONF_HOST, CONF_NAME, CONF_PASSWORD, + CONF_PATH, CONF_PORT, CONF_SSL, + CONF_TOKEN, CONF_UNIT_OF_MEASUREMENT, + CONF_URL, CONF_USERNAME, CONF_VALUE_TEMPLATE, CONF_VERIFY_SSL, @@ -23,79 +30,161 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import Entity from homeassistant.util import Throttle -from . import CONF_DB_NAME +from . import ( + API_VERSION_2, + COMPONENT_CONFIG_SCHEMA_CONNECTION, + CONF_BUCKET, + CONF_DB_NAME, + CONF_ORG, + DEFAULT_API_VERSION, + create_influx_url, + validate_version_specific_config, +) _LOGGER = logging.getLogger(__name__) -DEFAULT_HOST = "localhost" -DEFAULT_PORT = 8086 -DEFAULT_DATABASE = "home_assistant" -DEFAULT_SSL = False -DEFAULT_VERIFY_SSL = False DEFAULT_GROUP_FUNCTION = "mean" DEFAULT_FIELD = "value" CONF_QUERIES = "queries" +CONF_QUERIES_FLUX = "queries_flux" CONF_GROUP_FUNCTION = "group_function" CONF_FIELD = "field" CONF_MEASUREMENT_NAME = "measurement" CONF_WHERE = "where" +CONF_RANGE_START = "range_start" +CONF_RANGE_STOP = "range_stop" +CONF_FUNCTION = "function" +CONF_QUERY = "query" +CONF_IMPORTS = "imports" + +DEFAULT_RANGE_START = "-15m" +DEFAULT_RANGE_STOP = "now()" + MIN_TIME_BETWEEN_UPDATES = timedelta(seconds=60) -_QUERY_SCHEME = vol.Schema( +_QUERY_SENSOR_SCHEMA = vol.Schema( { vol.Required(CONF_NAME): cv.string, - vol.Required(CONF_MEASUREMENT_NAME): cv.string, - vol.Required(CONF_WHERE): cv.template, - vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, vol.Optional(CONF_VALUE_TEMPLATE): cv.template, - vol.Optional(CONF_DB_NAME, default=DEFAULT_DATABASE): cv.string, - vol.Optional(CONF_GROUP_FUNCTION, default=DEFAULT_GROUP_FUNCTION): cv.string, - vol.Optional(CONF_FIELD, default=DEFAULT_FIELD): cv.string, + vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, } ) -PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( - { - vol.Required(CONF_QUERIES): [_QUERY_SCHEME], - vol.Optional(CONF_HOST, default=DEFAULT_HOST): cv.string, - vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, - vol.Inclusive(CONF_USERNAME, "authentication"): cv.string, - vol.Inclusive(CONF_PASSWORD, "authentication"): cv.string, - vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean, - vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean, - } +_QUERY_SCHEMA = { + "InfluxQL": _QUERY_SENSOR_SCHEMA.extend( + { + vol.Optional(CONF_DB_NAME): cv.string, + vol.Required(CONF_MEASUREMENT_NAME): cv.string, + vol.Optional( + CONF_GROUP_FUNCTION, default=DEFAULT_GROUP_FUNCTION + ): cv.string, + vol.Optional(CONF_FIELD, default=DEFAULT_FIELD): cv.string, + vol.Required(CONF_WHERE): cv.template, + } + ), + "Flux": _QUERY_SENSOR_SCHEMA.extend( + { + vol.Optional(CONF_BUCKET): cv.string, + vol.Optional(CONF_RANGE_START, default=DEFAULT_RANGE_START): cv.string, + vol.Optional(CONF_RANGE_STOP, default=DEFAULT_RANGE_STOP): cv.string, + vol.Required(CONF_QUERY): cv.template, + vol.Optional(CONF_IMPORTS): vol.All(cv.ensure_list, [cv.string]), + vol.Optional(CONF_GROUP_FUNCTION): cv.string, + } + ), +} + + +def validate_query_format_for_version(conf: Dict) -> Dict: + """Ensure queries are provided in correct format based on API version.""" + if conf[CONF_API_VERSION] == API_VERSION_2: + if CONF_QUERIES_FLUX not in conf: + raise vol.Invalid( + f"{CONF_QUERIES_FLUX} is required when {CONF_API_VERSION} is {API_VERSION_2}" + ) + + else: + if CONF_QUERIES not in conf: + raise vol.Invalid( + f"{CONF_QUERIES} is required when {CONF_API_VERSION} is {DEFAULT_API_VERSION}" + ) + + return conf + + +PLATFORM_SCHEMA = vol.All( + PLATFORM_SCHEMA.extend(COMPONENT_CONFIG_SCHEMA_CONNECTION).extend( + { + vol.Exclusive(CONF_QUERIES, "queries"): [_QUERY_SCHEMA["InfluxQL"]], + vol.Exclusive(CONF_QUERIES_FLUX, "queries"): [_QUERY_SCHEMA["Flux"]], + } + ), + validate_version_specific_config, + validate_query_format_for_version, + create_influx_url, ) def setup_platform(hass, config, add_entities, discovery_info=None): """Set up the InfluxDB component.""" - influx_conf = { - "host": config[CONF_HOST], - "password": config.get(CONF_PASSWORD), - "port": config.get(CONF_PORT), - "ssl": config[CONF_SSL], - "username": config.get(CONF_USERNAME), - "verify_ssl": config.get(CONF_VERIFY_SSL), - } + use_v2_api = config[CONF_API_VERSION] == API_VERSION_2 + queries = None - dev = [] + if use_v2_api: + influx_conf = { + "url": config[CONF_URL], + "token": config[CONF_TOKEN], + "org": config[CONF_ORG], + } + bucket = config[CONF_BUCKET] + queries = config[CONF_QUERIES_FLUX] - for query in config.get(CONF_QUERIES): - sensor = InfluxSensor(hass, influx_conf, query) + for v2_query in queries: + if CONF_BUCKET not in v2_query: + v2_query[CONF_BUCKET] = bucket + + else: + influx_conf = { + "database": config[CONF_DB_NAME], + "verify_ssl": config[CONF_VERIFY_SSL], + } + + if CONF_USERNAME in config: + influx_conf["username"] = config[CONF_USERNAME] + + if CONF_PASSWORD in config: + influx_conf["password"] = config[CONF_PASSWORD] + + if CONF_HOST in config: + influx_conf["host"] = config[CONF_HOST] + + if CONF_PATH in config: + influx_conf["path"] = config[CONF_PATH] + + if CONF_PORT in config: + influx_conf["port"] = config[CONF_PORT] + + if CONF_SSL in config: + influx_conf["ssl"] = config[CONF_SSL] + + queries = config[CONF_QUERIES] + + entities = [] + for query in queries: + sensor = InfluxSensor(hass, influx_conf, query, use_v2_api) if sensor.connected: - dev.append(sensor) + entities.append(sensor) - add_entities(dev, True) + add_entities(entities, True) class InfluxSensor(Entity): """Implementation of a Influxdb sensor.""" - def __init__(self, hass, influx_conf, query): + def __init__(self, hass, influx_conf, query, use_v2_api): """Initialize the sensor.""" - self._name = query.get(CONF_NAME) self._unit_of_measurement = query.get(CONF_UNIT_OF_MEASUREMENT) value_template = query.get(CONF_VALUE_TEMPLATE) @@ -104,32 +193,54 @@ class InfluxSensor(Entity): self._value_template.hass = hass else: self._value_template = None - database = query.get(CONF_DB_NAME) self._state = None self._hass = hass - where_clause = query.get(CONF_WHERE) - where_clause.hass = hass + if use_v2_api: + influx = InfluxDBClientV2(**influx_conf) + query_api = influx.query_api() + query_clause = query.get(CONF_QUERY) + query_clause.hass = hass + bucket = query[CONF_BUCKET] + + else: + if CONF_DB_NAME in query: + kwargs = influx_conf.copy() + kwargs[CONF_DB_NAME] = query[CONF_DB_NAME] + else: + kwargs = influx_conf + + influx = InfluxDBClient(**kwargs) + where_clause = query.get(CONF_WHERE) + where_clause.hass = hass + query_api = None - influx = InfluxDBClient( - host=influx_conf["host"], - port=influx_conf["port"], - username=influx_conf["username"], - password=influx_conf["password"], - database=database, - ssl=influx_conf["ssl"], - verify_ssl=influx_conf["verify_ssl"], - ) try: - influx.query("SHOW SERIES LIMIT 1;") - self.connected = True - self.data = InfluxSensorData( - influx, - query.get(CONF_GROUP_FUNCTION), - query.get(CONF_FIELD), - query.get(CONF_MEASUREMENT_NAME), - where_clause, - ) + if query_api is not None: + query_api.query( + f'from(bucket: "{bucket}") |> range(start: -1ms) |> keep(columns: ["_time"]) |> limit(n: 1)' + ) + self.connected = True + self.data = InfluxSensorDataV2( + query_api, + bucket, + query.get(CONF_RANGE_START), + query.get(CONF_RANGE_STOP), + query_clause, + query.get(CONF_IMPORTS), + query.get(CONF_GROUP_FUNCTION), + ) + + else: + influx.query("SHOW SERIES LIMIT 1;") + self.connected = True + self.data = InfluxSensorDataV1( + influx, + query.get(CONF_GROUP_FUNCTION), + query.get(CONF_FIELD), + query.get(CONF_MEASUREMENT_NAME), + where_clause, + ) except exceptions.InfluxDBClientError as exc: _LOGGER.error( "Database host is not accessible due to '%s', please" @@ -138,6 +249,15 @@ class InfluxSensor(Entity): exc, ) self.connected = False + except ApiException as exc: + _LOGGER.error( + "Bucket is not accessible due to '%s', please " + "check your entries in the configuration file (url, org, " + "bucket, etc.) and verify that the org and bucket exist and the " + "provided token has READ access.", + exc, + ) + self.connected = False @property def name(self): @@ -173,8 +293,76 @@ class InfluxSensor(Entity): self._state = value -class InfluxSensorData: - """Class for handling the data retrieval.""" +class InfluxSensorDataV2: + """Class for handling the data retrieval with v2 API.""" + + def __init__( + self, query_api, bucket, range_start, range_stop, query, imports, group + ): + """Initialize the data object.""" + self.query_api = query_api + self.bucket = bucket + self.range_start = range_start + self.range_stop = range_stop + self.query = query + self.imports = imports + self.group = group + self.value = None + self.full_query = None + + self.query_prefix = f'from(bucket:"{bucket}") |> range(start: {range_start}, stop: {range_stop}) |>' + if imports is not None: + for i in imports: + self.query_prefix = f'import "{i}" {self.query_prefix}' + + if group is None: + self.query_postfix = "|> limit(n: 1)" + else: + self.query_postfix = f'|> {group}(column: "_value")' + + @Throttle(MIN_TIME_BETWEEN_UPDATES) + def update(self): + """Get the latest data by querying influx.""" + _LOGGER.debug("Rendering query: %s", self.query) + try: + rendered_query = self.query.render() + except TemplateError as ex: + _LOGGER.error("Could not render query template: %s", ex) + return + + self.full_query = f"{self.query_prefix} {rendered_query} {self.query_postfix}" + + _LOGGER.info("Running query: %s", self.full_query) + + try: + tables = self.query_api.query(self.full_query) + except ApiException as exc: + _LOGGER.error( + "Could not execute query '%s' due to '%s', " + "Check the syntax of your query", + self.full_query, + exc, + ) + self.value = None + return + + if not tables: + _LOGGER.warning( + "Query returned no results, sensor state set to UNKNOWN: %s", + self.full_query, + ) + self.value = None + else: + if len(tables) > 1: + _LOGGER.warning( + "Query returned multiple tables, only value from first one is shown: %s", + self.full_query, + ) + self.value = tables[0].records[0].values["_value"] + + +class InfluxSensorDataV1: + """Class for handling the data retrieval with v1 API.""" def __init__(self, influx, group, field, measurement, where): """Initialize the data object.""" @@ -200,7 +388,18 @@ class InfluxSensorData: _LOGGER.info("Running query: %s", self.query) - points = list(self.influx.query(self.query).get_points()) + try: + points = list(self.influx.query(self.query).get_points()) + except exceptions.InfluxDBClientError as exc: + _LOGGER.error( + "Could not execute query '%s' due to '%s', " + "Check the syntax of your query", + self.query, + exc, + ) + self.value = None + return + if not points: _LOGGER.warning( "Query returned no points, sensor state set to UNKNOWN: %s", self.query diff --git a/requirements_all.txt b/requirements_all.txt index b92941877be..e8aeba906f4 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -784,6 +784,9 @@ ihcsdk==2.7.0 # homeassistant.components.incomfort incomfort-client==0.4.0 +# homeassistant.components.influxdb +influxdb-client==1.6.0 + # homeassistant.components.influxdb influxdb==5.2.3 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 2f494c40c76..d4e43e64f1b 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -345,6 +345,9 @@ huawei-lte-api==1.4.12 # homeassistant.components.iaqualink iaqualink==0.3.4 +# homeassistant.components.influxdb +influxdb-client==1.6.0 + # homeassistant.components.influxdb influxdb==5.2.3 diff --git a/tests/components/influxdb/test_init.py b/tests/components/influxdb/test_init.py index cdbb39ba3ce..f9514f7ebff 100644 --- a/tests/components/influxdb/test_init.py +++ b/tests/components/influxdb/test_init.py @@ -1,7 +1,7 @@ """The tests for the InfluxDB component.""" import datetime -import unittest -from unittest import mock + +import pytest import homeassistant.components.influxdb as influxdb from homeassistant.const import ( @@ -11,749 +11,1103 @@ from homeassistant.const import ( STATE_STANDBY, UNIT_PERCENTAGE, ) -from homeassistant.setup import setup_component +from homeassistant.setup import async_setup_component -from tests.common import get_test_home_assistant +from tests.async_mock import MagicMock, Mock, call, patch + +BASE_V1_CONFIG = {} +BASE_V2_CONFIG = { + "api_version": influxdb.API_VERSION_2, + "organization": "org", + "token": "token", +} -@mock.patch("homeassistant.components.influxdb.InfluxDBClient") -@mock.patch( - "homeassistant.components.influxdb.InfluxThread.batch_timeout", - mock.Mock(return_value=0), -) -class TestInfluxDB(unittest.TestCase): - """Test the InfluxDB component.""" +@pytest.fixture(autouse=True) +def mock_batch_timeout(hass, monkeypatch): + """Mock the event bus listener and the batch timeout for tests.""" + hass.bus.listen = MagicMock() + monkeypatch.setattr( + "homeassistant.components.influxdb.InfluxThread.batch_timeout", + Mock(return_value=0), + ) - def setUp(self): - """Set up things to be run when tests are started.""" - self.hass = get_test_home_assistant() - self.handler_method = None - self.hass.bus.listen = mock.Mock() - self.addCleanup(self.tear_down_cleanup) - def tear_down_cleanup(self): - """Clear data.""" - self.hass.stop() +@pytest.fixture(name="mock_client") +def mock_client_fixture(request): + """Patch the InfluxDBClient object with mock for version under test.""" + if request.param == influxdb.API_VERSION_2: + client_target = "homeassistant.components.influxdb.InfluxDBClientV2" + else: + client_target = "homeassistant.components.influxdb.InfluxDBClient" - def test_setup_config_full(self, mock_client): - """Test the setup with full configuration.""" - config = { - "influxdb": { - "host": "host", - "port": 123, - "database": "db", + with patch(client_target) as client: + yield client + + +@pytest.fixture(name="get_mock_call") +def get_mock_call_fixture(request): + """Get version specific lambda to make write API call mock.""" + if request.param == influxdb.API_VERSION_2: + return lambda body: call(bucket=influxdb.DEFAULT_BUCKET, record=body) + # pylint: disable=unnecessary-lambda + return lambda body: call(body) + + +def _get_write_api_mock_v1(mock_influx_client): + """Return the write api mock for the V1 client.""" + return mock_influx_client.return_value.write_points + + +def _get_write_api_mock_v2(mock_influx_client): + """Return the write api mock for the V2 client.""" + return mock_influx_client.return_value.write_api.return_value.write + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api", + [ + ( + influxdb.DEFAULT_API_VERSION, + { + "api_version": influxdb.DEFAULT_API_VERSION, "username": "user", "password": "password", - "max_retries": 4, - "ssl": "False", "verify_ssl": "False", - } + }, + _get_write_api_mock_v1, + ), + ( + influxdb.API_VERSION_2, + { + "api_version": influxdb.API_VERSION_2, + "token": "token", + "organization": "organization", + "bucket": "bucket", + }, + _get_write_api_mock_v2, + ), + ], + indirect=["mock_client"], +) +async def test_setup_config_full(hass, mock_client, config_ext, get_write_api): + """Test the setup with full configuration.""" + config = { + "influxdb": { + "host": "host", + "port": 123, + "database": "db", + "max_retries": 4, + "ssl": "False", } - assert setup_component(self.hass, influxdb.DOMAIN, config) - assert self.hass.bus.listen.called - assert EVENT_STATE_CHANGED == self.hass.bus.listen.call_args_list[0][0][0] - assert mock_client.return_value.write_points.call_count == 1 + } + config["influxdb"].update(config_ext) - def test_setup_config_defaults(self, mock_client): - """Test the setup with default configuration.""" - config = {"influxdb": {"host": "host", "username": "user", "password": "pass"}} - assert setup_component(self.hass, influxdb.DOMAIN, config) - assert self.hass.bus.listen.called - assert EVENT_STATE_CHANGED == self.hass.bus.listen.call_args_list[0][0][0] + assert await async_setup_component(hass, influxdb.DOMAIN, config) + await hass.async_block_till_done() + assert hass.bus.listen.called + assert EVENT_STATE_CHANGED == hass.bus.listen.call_args_list[0][0][0] + assert get_write_api(mock_client).call_count == 1 - def test_setup_minimal_config(self, mock_client): - """Test the setup with minimal configuration.""" - config = {"influxdb": {}} - assert setup_component(self.hass, influxdb.DOMAIN, config) +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api", + [ + (influxdb.DEFAULT_API_VERSION, BASE_V1_CONFIG, _get_write_api_mock_v1), + (influxdb.API_VERSION_2, BASE_V2_CONFIG, _get_write_api_mock_v2), + ], + indirect=["mock_client"], +) +async def test_setup_minimal_config(hass, mock_client, config_ext, get_write_api): + """Test the setup with minimal configuration and defaults.""" + config = {"influxdb": {}} + config["influxdb"].update(config_ext) - def test_setup_missing_password(self, mock_client): - """Test the setup with existing username and missing password.""" - config = {"influxdb": {"username": "user"}} + assert await async_setup_component(hass, influxdb.DOMAIN, config) + await hass.async_block_till_done() + assert hass.bus.listen.called + assert EVENT_STATE_CHANGED == hass.bus.listen.call_args_list[0][0][0] + assert get_write_api(mock_client).call_count == 1 - assert not setup_component(self.hass, influxdb.DOMAIN, config) - def _setup(self, mock_client, **kwargs): - """Set up the client.""" - config = { - "influxdb": { - "host": "host", +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api", + [ + (influxdb.DEFAULT_API_VERSION, {"username": "user"}, _get_write_api_mock_v1), + (influxdb.DEFAULT_API_VERSION, {"token": "token"}, _get_write_api_mock_v1), + ( + influxdb.API_VERSION_2, + {"api_version": influxdb.API_VERSION_2, "organization": "organization"}, + _get_write_api_mock_v2, + ), + ( + influxdb.API_VERSION_2, + { + "api_version": influxdb.API_VERSION_2, + "token": "token", + "organization": "organization", "username": "user", "password": "pass", - "exclude": { - "entities": ["fake.blacklisted"], - "domains": ["another_fake"], - }, - } + }, + _get_write_api_mock_v2, + ), + ], + indirect=["mock_client"], +) +async def test_invalid_config(hass, mock_client, config_ext, get_write_api): + """Test the setup with invalid config or config options specified for wrong version.""" + config = {"influxdb": {}} + config["influxdb"].update(config_ext) + + assert not await async_setup_component(hass, influxdb.DOMAIN, config) + + +async def _setup(hass, mock_influx_client, config_ext, get_write_api): + """Prepare client for next test and return event handler method.""" + config = { + "influxdb": { + "host": "host", + "exclude": {"entities": ["fake.blacklisted"], "domains": ["another_fake"]}, } - config["influxdb"].update(kwargs) - assert setup_component(self.hass, influxdb.DOMAIN, config) - self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] - mock_client.return_value.write_points.reset_mock() + } + config["influxdb"].update(config_ext) + assert await async_setup_component(hass, influxdb.DOMAIN, config) + await hass.async_block_till_done() + # A call is made to the write API during setup to test the connection. + # Therefore we reset the write API mock here before the test begins. + get_write_api(mock_influx_client).reset_mock() + return hass.bus.listen.call_args_list[0][0][1] - def test_event_listener(self, mock_client): - """Test the event listener.""" - self._setup(mock_client) - # map of HA State to valid influxdb [state, value] fields - valid = { - "1": [None, 1], - "1.0": [None, 1.0], - STATE_ON: [STATE_ON, 1], - STATE_OFF: [STATE_OFF, 0], - STATE_STANDBY: [STATE_STANDBY, None], - "foo": ["foo", None], +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener.""" + handler_method = await _setup(hass, mock_client, config_ext, get_write_api) + + # map of HA State to valid influxdb [state, value] fields + valid = { + "1": [None, 1], + "1.0": [None, 1.0], + STATE_ON: [STATE_ON, 1], + STATE_OFF: [STATE_OFF, 0], + STATE_STANDBY: [STATE_STANDBY, None], + "foo": ["foo", None], + } + for in_, out in valid.items(): + attrs = { + "unit_of_measurement": "foobars", + "longitude": "1.1", + "latitude": "2.2", + "battery_level": f"99{UNIT_PERCENTAGE}", + "temperature": "20c", + "last_seen": "Last seen 23 minutes ago", + "updated_at": datetime.datetime(2017, 1, 1, 0, 0), + "multi_periods": "0.120.240.2023873", } - for in_, out in valid.items(): - attrs = { - "unit_of_measurement": "foobars", - "longitude": "1.1", - "latitude": "2.2", - "battery_level": f"99{UNIT_PERCENTAGE}", - "temperature": "20c", - "last_seen": "Last seen 23 minutes ago", - "updated_at": datetime.datetime(2017, 1, 1, 0, 0), - "multi_periods": "0.120.240.2023873", - } - state = mock.MagicMock( - state=in_, - domain="fake", - entity_id="fake.entity-id", - object_id="entity", - attributes=attrs, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": "foobars", - "tags": {"domain": "fake", "entity_id": "entity"}, - "time": 12345, - "fields": { - "longitude": 1.1, - "latitude": 2.2, - "battery_level_str": f"99{UNIT_PERCENTAGE}", - "battery_level": 99.0, - "temperature_str": "20c", - "temperature": 20.0, - "last_seen_str": "Last seen 23 minutes ago", - "last_seen": 23.0, - "updated_at_str": "2017-01-01 00:00:00", - "updated_at": 20170101000000, - "multi_periods_str": "0.120.240.2023873", - }, - } - ] - if out[0] is not None: - body[0]["fields"]["state"] = out[0] - if out[1] is not None: - body[0]["fields"]["value"] = out[1] - - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call(body) - mock_client.return_value.write_points.reset_mock() - - def test_event_listener_no_units(self, mock_client): - """Test the event listener for missing units.""" - self._setup(mock_client) - - for unit in (None, ""): - if unit: - attrs = {"unit_of_measurement": unit} - else: - attrs = {} - state = mock.MagicMock( - state=1, - domain="fake", - entity_id="fake.entity-id", - object_id="entity", - attributes=attrs, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": "fake.entity-id", - "tags": {"domain": "fake", "entity_id": "entity"}, - "time": 12345, - "fields": {"value": 1}, - } - ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call(body) - mock_client.return_value.write_points.reset_mock() - - def test_event_listener_inf(self, mock_client): - """Test the event listener for missing units.""" - self._setup(mock_client) - - attrs = {"bignumstring": "9" * 999, "nonumstring": "nan"} - state = mock.MagicMock( - state=8, + state = MagicMock( + state=in_, domain="fake", entity_id="fake.entity-id", object_id="entity", attributes=attrs, ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": "foobars", + "tags": {"domain": "fake", "entity_id": "entity"}, + "time": 12345, + "fields": { + "longitude": 1.1, + "latitude": 2.2, + "battery_level_str": f"99{UNIT_PERCENTAGE}", + "battery_level": 99.0, + "temperature_str": "20c", + "temperature": 20.0, + "last_seen_str": "Last seen 23 minutes ago", + "last_seen": 23.0, + "updated_at_str": "2017-01-01 00:00:00", + "updated_at": 20170101000000, + "multi_periods_str": "0.120.240.2023873", + }, + } + ] + if out[0] is not None: + body[0]["fields"]["state"] = out[0] + if out[1] is not None: + body[0]["fields"]["value"] = out[1] + + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() + + write_api = get_write_api(mock_client) + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + write_api.reset_mock() + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_no_units( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener for missing units.""" + handler_method = await _setup(hass, mock_client, config_ext, get_write_api) + + for unit in (None, ""): + if unit: + attrs = {"unit_of_measurement": unit} + else: + attrs = {} + state = MagicMock( + state=1, + domain="fake", + entity_id="fake.entity-id", + object_id="entity", + attributes=attrs, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) body = [ { "measurement": "fake.entity-id", "tags": {"domain": "fake", "entity_id": "entity"}, "time": 12345, - "fields": {"value": 8}, + "fields": {"value": 1}, } ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call(body) - mock_client.return_value.write_points.reset_mock() + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() - def test_event_listener_states(self, mock_client): - """Test the event listener against ignored states.""" - self._setup(mock_client) + write_api = get_write_api(mock_client) + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + write_api.reset_mock() - for state_state in (1, "unknown", "", "unavailable"): - state = mock.MagicMock( - state=state_state, - domain="fake", - entity_id="fake.entity-id", - object_id="entity", - attributes={}, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": "fake.entity-id", - "tags": {"domain": "fake", "entity_id": "entity"}, - "time": 12345, - "fields": {"value": 1}, - } - ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - if state_state == 1: - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call( - body - ) - else: - assert not mock_client.return_value.write_points.called - mock_client.return_value.write_points.reset_mock() - def test_event_listener_blacklist(self, mock_client): - """Test the event listener against a blacklist.""" - self._setup(mock_client) +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_inf( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener with large or invalid numbers.""" + handler_method = await _setup(hass, mock_client, config_ext, get_write_api) - for entity_id in ("ok", "blacklisted"): - state = mock.MagicMock( - state=1, - domain="fake", - entity_id=f"fake.{entity_id}", - object_id=entity_id, - attributes={}, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": f"fake.{entity_id}", - "tags": {"domain": "fake", "entity_id": entity_id}, - "time": 12345, - "fields": {"value": 1}, - } - ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - if entity_id == "ok": - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call( - body - ) - else: - assert not mock_client.return_value.write_points.called - mock_client.return_value.write_points.reset_mock() - - def test_event_listener_blacklist_domain(self, mock_client): - """Test the event listener against a blacklist.""" - self._setup(mock_client) - - for domain in ("ok", "another_fake"): - state = mock.MagicMock( - state=1, - domain=domain, - entity_id=f"{domain}.something", - object_id="something", - attributes={}, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": f"{domain}.something", - "tags": {"domain": domain, "entity_id": "something"}, - "time": 12345, - "fields": {"value": 1}, - } - ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - if domain == "ok": - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call( - body - ) - else: - assert not mock_client.return_value.write_points.called - mock_client.return_value.write_points.reset_mock() - - def test_event_listener_whitelist(self, mock_client): - """Test the event listener against a whitelist.""" - config = { - "influxdb": { - "host": "host", - "username": "user", - "password": "pass", - "include": {"entities": ["fake.included"]}, - } + attrs = {"bignumstring": "9" * 999, "nonumstring": "nan"} + state = MagicMock( + state=8, + domain="fake", + entity_id="fake.entity-id", + object_id="entity", + attributes=attrs, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": "fake.entity-id", + "tags": {"domain": "fake", "entity_id": "entity"}, + "time": 12345, + "fields": {"value": 8}, } - assert setup_component(self.hass, influxdb.DOMAIN, config) - self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] - mock_client.return_value.write_points.reset_mock() + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() - for entity_id in ("included", "default"): - state = mock.MagicMock( - state=1, - domain="fake", - entity_id=f"fake.{entity_id}", - object_id=entity_id, - attributes={}, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": f"fake.{entity_id}", - "tags": {"domain": "fake", "entity_id": entity_id}, - "time": 12345, - "fields": {"value": 1}, - } - ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - if entity_id == "included": - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call( - body - ) - else: - assert not mock_client.return_value.write_points.called - mock_client.return_value.write_points.reset_mock() + write_api = get_write_api(mock_client) + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) - def test_event_listener_whitelist_domain(self, mock_client): - """Test the event listener against a whitelist.""" - config = { - "influxdb": { - "host": "host", - "username": "user", - "password": "pass", - "include": {"domains": ["fake"]}, + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_states( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener against ignored states.""" + handler_method = await _setup(hass, mock_client, config_ext, get_write_api) + + for state_state in (1, "unknown", "", "unavailable"): + state = MagicMock( + state=state_state, + domain="fake", + entity_id="fake.entity-id", + object_id="entity", + attributes={}, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": "fake.entity-id", + "tags": {"domain": "fake", "entity_id": "entity"}, + "time": 12345, + "fields": {"value": 1}, } - } - assert setup_component(self.hass, influxdb.DOMAIN, config) - self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] - mock_client.return_value.write_points.reset_mock() + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() - for domain in ("fake", "another_fake"): - state = mock.MagicMock( - state=1, - domain=domain, - entity_id=f"{domain}.something", - object_id="something", - attributes={}, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": f"{domain}.something", - "tags": {"domain": domain, "entity_id": "something"}, - "time": 12345, - "fields": {"value": 1}, - } - ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - if domain == "fake": - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call( - body - ) - else: - assert not mock_client.return_value.write_points.called - mock_client.return_value.write_points.reset_mock() + write_api = get_write_api(mock_client) + if state_state == 1: + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + else: + assert not write_api.called + write_api.reset_mock() - def test_event_listener_whitelist_domain_and_entities(self, mock_client): - """Test the event listener against a whitelist.""" - config = { - "influxdb": { - "host": "host", - "username": "user", - "password": "pass", - "include": {"domains": ["fake"], "entities": ["other.one"]}, + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_blacklist( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener against a blacklist.""" + handler_method = await _setup(hass, mock_client, config_ext, get_write_api) + + for entity_id in ("ok", "blacklisted"): + state = MagicMock( + state=1, + domain="fake", + entity_id=f"fake.{entity_id}", + object_id=entity_id, + attributes={}, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": f"fake.{entity_id}", + "tags": {"domain": "fake", "entity_id": entity_id}, + "time": 12345, + "fields": {"value": 1}, } - } - assert setup_component(self.hass, influxdb.DOMAIN, config) - self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] - mock_client.return_value.write_points.reset_mock() + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() - for domain in ("fake", "another_fake"): - state = mock.MagicMock( - state=1, - domain=domain, - entity_id=f"{domain}.something", - object_id="something", - attributes={}, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": f"{domain}.something", - "tags": {"domain": domain, "entity_id": "something"}, - "time": 12345, - "fields": {"value": 1}, - } - ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - if domain == "fake": - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call( - body - ) - else: - assert not mock_client.return_value.write_points.called - mock_client.return_value.write_points.reset_mock() + write_api = get_write_api(mock_client) + if entity_id == "ok": + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + else: + assert not write_api.called + write_api.reset_mock() - for entity_id in ("one", "two"): - state = mock.MagicMock( - state=1, - domain="other", - entity_id=f"other.{entity_id}", - object_id=entity_id, - attributes={}, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": f"other.{entity_id}", - "tags": {"domain": "other", "entity_id": entity_id}, - "time": 12345, - "fields": {"value": 1}, - } - ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - if entity_id == "one": - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call( - body - ) - else: - assert not mock_client.return_value.write_points.called - mock_client.return_value.write_points.reset_mock() - def test_event_listener_invalid_type(self, mock_client): - """Test the event listener when an attribute has an invalid type.""" - self._setup(mock_client) +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_blacklist_domain( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener against a domain blacklist.""" + handler_method = await _setup(hass, mock_client, config_ext, get_write_api) - # map of HA State to valid influxdb [state, value] fields - valid = { - "1": [None, 1], - "1.0": [None, 1.0], - STATE_ON: [STATE_ON, 1], - STATE_OFF: [STATE_OFF, 0], - STATE_STANDBY: [STATE_STANDBY, None], - "foo": ["foo", None], - } - for in_, out in valid.items(): - attrs = { - "unit_of_measurement": "foobars", - "longitude": "1.1", - "latitude": "2.2", - "invalid_attribute": ["value1", "value2"], + for domain in ("ok", "another_fake"): + state = MagicMock( + state=1, + domain=domain, + entity_id=f"{domain}.something", + object_id="something", + attributes={}, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": f"{domain}.something", + "tags": {"domain": domain, "entity_id": "something"}, + "time": 12345, + "fields": {"value": 1}, } - state = mock.MagicMock( - state=in_, - domain="fake", - entity_id="fake.entity-id", - object_id="entity", - attributes=attrs, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": "foobars", - "tags": {"domain": "fake", "entity_id": "entity"}, - "time": 12345, - "fields": { - "longitude": 1.1, - "latitude": 2.2, - "invalid_attribute_str": "['value1', 'value2']", - }, - } - ] - if out[0] is not None: - body[0]["fields"]["state"] = out[0] - if out[1] is not None: - body[0]["fields"]["value"] = out[1] + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call(body) - mock_client.return_value.write_points.reset_mock() + write_api = get_write_api(mock_client) + if domain == "ok": + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + else: + assert not write_api.called + write_api.reset_mock() - def test_event_listener_default_measurement(self, mock_client): - """Test the event listener with a default measurement.""" - config = { - "influxdb": { - "host": "host", - "username": "user", - "password": "pass", - "default_measurement": "state", - "exclude": {"entities": ["fake.blacklisted"]}, + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_whitelist( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener against a whitelist.""" + config = {"include": {"entities": ["fake.included"]}} + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + + for entity_id in ("included", "default"): + state = MagicMock( + state=1, + domain="fake", + entity_id=f"fake.{entity_id}", + object_id=entity_id, + attributes={}, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": f"fake.{entity_id}", + "tags": {"domain": "fake", "entity_id": entity_id}, + "time": 12345, + "fields": {"value": 1}, } - } - assert setup_component(self.hass, influxdb.DOMAIN, config) - self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] - mock_client.return_value.write_points.reset_mock() + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() - for entity_id in ("ok", "blacklisted"): - state = mock.MagicMock( - state=1, - domain="fake", - entity_id=f"fake.{entity_id}", - object_id=entity_id, - attributes={}, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": "state", - "tags": {"domain": "fake", "entity_id": entity_id}, - "time": 12345, - "fields": {"value": 1}, - } - ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - if entity_id == "ok": - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call( - body - ) - else: - assert not mock_client.return_value.write_points.called - mock_client.return_value.write_points.reset_mock() + write_api = get_write_api(mock_client) + if entity_id == "included": + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + else: + assert not write_api.called + write_api.reset_mock() - def test_event_listener_unit_of_measurement_field(self, mock_client): - """Test the event listener for unit of measurement field.""" - config = { - "influxdb": { - "host": "host", - "username": "user", - "password": "pass", - "override_measurement": "state", + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_whitelist_domain( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener against a domain whitelist.""" + config = {"include": {"domains": ["fake"]}} + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + + for domain in ("fake", "another_fake"): + state = MagicMock( + state=1, + domain=domain, + entity_id=f"{domain}.something", + object_id="something", + attributes={}, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": f"{domain}.something", + "tags": {"domain": domain, "entity_id": "something"}, + "time": 12345, + "fields": {"value": 1}, } - } - assert setup_component(self.hass, influxdb.DOMAIN, config) - self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] - mock_client.return_value.write_points.reset_mock() + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() - attrs = {"unit_of_measurement": "foobars"} - state = mock.MagicMock( - state="foo", + write_api = get_write_api(mock_client) + if domain == "fake": + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + else: + assert not write_api.called + write_api.reset_mock() + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_whitelist_domain_and_entities( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener against a domain and entity whitelist.""" + config = {"include": {"domains": ["fake"], "entities": ["other.one"]}} + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + + for domain in ("fake", "another_fake"): + state = MagicMock( + state=1, + domain=domain, + entity_id=f"{domain}.something", + object_id="something", + attributes={}, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": f"{domain}.something", + "tags": {"domain": domain, "entity_id": "something"}, + "time": 12345, + "fields": {"value": 1}, + } + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() + + write_api = get_write_api(mock_client) + if domain == "fake": + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + else: + assert not write_api.called + write_api.reset_mock() + + for entity_id in ("one", "two"): + state = MagicMock( + state=1, + domain="other", + entity_id=f"other.{entity_id}", + object_id=entity_id, + attributes={}, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": f"other.{entity_id}", + "tags": {"domain": "other", "entity_id": entity_id}, + "time": 12345, + "fields": {"value": 1}, + } + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() + + write_api = get_write_api(mock_client) + if entity_id == "one": + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + else: + assert not write_api.called + write_api.reset_mock() + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_invalid_type( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener when an attribute has an invalid type.""" + handler_method = await _setup(hass, mock_client, config_ext, get_write_api) + + # map of HA State to valid influxdb [state, value] fields + valid = { + "1": [None, 1], + "1.0": [None, 1.0], + STATE_ON: [STATE_ON, 1], + STATE_OFF: [STATE_OFF, 0], + STATE_STANDBY: [STATE_STANDBY, None], + "foo": ["foo", None], + } + for in_, out in valid.items(): + attrs = { + "unit_of_measurement": "foobars", + "longitude": "1.1", + "latitude": "2.2", + "invalid_attribute": ["value1", "value2"], + } + state = MagicMock( + state=in_, domain="fake", entity_id="fake.entity-id", object_id="entity", attributes=attrs, ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) + event = MagicMock(data={"new_state": state}, time_fired=12345) body = [ { - "measurement": "state", + "measurement": "foobars", "tags": {"domain": "fake", "entity_id": "entity"}, "time": 12345, - "fields": {"state": "foo", "unit_of_measurement_str": "foobars"}, + "fields": { + "longitude": 1.1, + "latitude": 2.2, + "invalid_attribute_str": "['value1', 'value2']", + }, } ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call(body) - mock_client.return_value.write_points.reset_mock() + if out[0] is not None: + body[0]["fields"]["state"] = out[0] + if out[1] is not None: + body[0]["fields"]["value"] = out[1] - def test_event_listener_tags_attributes(self, mock_client): - """Test the event listener when some attributes should be tags.""" - config = { - "influxdb": { - "host": "host", - "username": "user", - "password": "pass", - "tags_attributes": ["friendly_fake"], - } + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() + + write_api = get_write_api(mock_client) + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + write_api.reset_mock() + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_default_measurement( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener with a default measurement.""" + config = {"default_measurement": "state"} + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + + state = MagicMock( + state=1, domain="fake", entity_id="fake.ok", object_id="ok", attributes={}, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": "state", + "tags": {"domain": "fake", "entity_id": "ok"}, + "time": 12345, + "fields": {"value": 1}, } - assert setup_component(self.hass, influxdb.DOMAIN, config) - self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] - mock_client.return_value.write_points.reset_mock() + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() - attrs = {"friendly_fake": "tag_str", "field_fake": "field_str"} - state = mock.MagicMock( + write_api = get_write_api(mock_client) + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_unit_of_measurement_field( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener for unit of measurement field.""" + config = {"override_measurement": "state"} + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + + attrs = {"unit_of_measurement": "foobars"} + state = MagicMock( + state="foo", + domain="fake", + entity_id="fake.entity-id", + object_id="entity", + attributes=attrs, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": "state", + "tags": {"domain": "fake", "entity_id": "entity"}, + "time": 12345, + "fields": {"state": "foo", "unit_of_measurement_str": "foobars"}, + } + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() + + write_api = get_write_api(mock_client) + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_tags_attributes( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener when some attributes should be tags.""" + config = {"tags_attributes": ["friendly_fake"]} + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + + attrs = {"friendly_fake": "tag_str", "field_fake": "field_str"} + state = MagicMock( + state=1, + domain="fake", + entity_id="fake.something", + object_id="something", + attributes=attrs, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": "fake.something", + "tags": { + "domain": "fake", + "entity_id": "something", + "friendly_fake": "tag_str", + }, + "time": 12345, + "fields": {"value": 1, "field_fake_str": "field_str"}, + } + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() + + write_api = get_write_api(mock_client) + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + + +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_component_override_measurement( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener with overridden measurements.""" + config = { + "component_config": { + "sensor.fake_humidity": {"override_measurement": "humidity"} + }, + "component_config_glob": { + "binary_sensor.*motion": {"override_measurement": "motion"} + }, + "component_config_domain": {"climate": {"override_measurement": "hvac"}}, + } + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + + test_components = [ + {"domain": "sensor", "id": "fake_humidity", "res": "humidity"}, + {"domain": "binary_sensor", "id": "fake_motion", "res": "motion"}, + {"domain": "climate", "id": "fake_thermostat", "res": "hvac"}, + {"domain": "other", "id": "just_fake", "res": "other.just_fake"}, + ] + for comp in test_components: + state = MagicMock( state=1, - domain="fake", - entity_id="fake.something", - object_id="something", - attributes=attrs, + domain=comp["domain"], + entity_id=f"{comp['domain']}.{comp['id']}", + object_id=comp["id"], + attributes={}, ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) + event = MagicMock(data={"new_state": state}, time_fired=12345) body = [ { - "measurement": "fake.something", - "tags": { - "domain": "fake", - "entity_id": "something", - "friendly_fake": "tag_str", - }, + "measurement": comp["res"], + "tags": {"domain": comp["domain"], "entity_id": comp["id"]}, "time": 12345, - "fields": {"value": 1, "field_fake_str": "field_str"}, + "fields": {"value": 1}, } ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call(body) - mock_client.return_value.write_points.reset_mock() + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() - def test_event_listener_component_override_measurement(self, mock_client): - """Test the event listener with overridden measurements.""" - config = { - "influxdb": { - "host": "host", - "username": "user", - "password": "pass", - "component_config": { - "sensor.fake_humidity": {"override_measurement": "humidity"} - }, - "component_config_glob": { - "binary_sensor.*motion": {"override_measurement": "motion"} - }, - "component_config_domain": { - "climate": {"override_measurement": "hvac"} - }, - } - } - assert setup_component(self.hass, influxdb.DOMAIN, config) - self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] - mock_client.return_value.write_points.reset_mock() + write_api = get_write_api(mock_client) + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + write_api.reset_mock() - test_components = [ - {"domain": "sensor", "id": "fake_humidity", "res": "humidity"}, - {"domain": "binary_sensor", "id": "fake_motion", "res": "motion"}, - {"domain": "climate", "id": "fake_thermostat", "res": "hvac"}, - {"domain": "other", "id": "just_fake", "res": "other.just_fake"}, - ] - for comp in test_components: - state = mock.MagicMock( - state=1, - domain=comp["domain"], - entity_id=f"{comp['domain']}.{comp['id']}", - object_id=comp["id"], - attributes={}, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - body = [ - { - "measurement": comp["res"], - "tags": {"domain": comp["domain"], "entity_id": comp["id"]}, - "time": 12345, - "fields": {"value": 1}, - } - ] - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - assert mock_client.return_value.write_points.call_count == 1 - assert mock_client.return_value.write_points.call_args == mock.call(body) - mock_client.return_value.write_points.reset_mock() - def test_scheduled_write(self, mock_client): - """Test the event listener to retry after write failures.""" - config = { - "influxdb": { - "host": "host", - "username": "user", - "password": "pass", - "max_retries": 1, - } - } - assert setup_component(self.hass, influxdb.DOMAIN, config) - self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] - mock_client.return_value.write_points.reset_mock() +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_scheduled_write( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener retries after a write failure.""" + config = {"max_retries": 1} + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) - state = mock.MagicMock( - state=1, - domain="fake", - entity_id="entity.id", - object_id="entity", - attributes={}, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) - mock_client.return_value.write_points.side_effect = IOError("foo") + state = MagicMock( + state=1, + domain="fake", + entity_id="entity.id", + object_id="entity", + attributes={}, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + write_api = get_write_api(mock_client) + write_api.side_effect = IOError("foo") - # Write fails - with mock.patch.object(influxdb.time, "sleep") as mock_sleep: - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - assert mock_sleep.called - json_data = mock_client.return_value.write_points.call_args[0][0] - assert mock_client.return_value.write_points.call_count == 2 - mock_client.return_value.write_points.assert_called_with(json_data) + # Write fails + with patch.object(influxdb.time, "sleep") as mock_sleep: + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() + assert mock_sleep.called + assert write_api.call_count == 2 - # Write works again - mock_client.return_value.write_points.side_effect = None - with mock.patch.object(influxdb.time, "sleep") as mock_sleep: - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() - assert not mock_sleep.called - assert mock_client.return_value.write_points.call_count == 3 + # Write works again + write_api.side_effect = None + with patch.object(influxdb.time, "sleep") as mock_sleep: + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() + assert not mock_sleep.called + assert write_api.call_count == 3 - def test_queue_backlog_full(self, mock_client): - """Test the event listener to drop old events.""" - self._setup(mock_client) - state = mock.MagicMock( - state=1, - domain="fake", - entity_id="entity.id", - object_id="entity", - attributes={}, - ) - event = mock.MagicMock(data={"new_state": state}, time_fired=12345) +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_backlog_full( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener drops old events when backlog gets full.""" + handler_method = await _setup(hass, mock_client, config_ext, get_write_api) - monotonic_time = 0 + state = MagicMock( + state=1, + domain="fake", + entity_id="entity.id", + object_id="entity", + attributes={}, + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) - def fast_monotonic(): - """Monotonic time that ticks fast enough to cause a timeout.""" - nonlocal monotonic_time - monotonic_time += 60 - return monotonic_time + monotonic_time = 0 - with mock.patch( - "homeassistant.components.influxdb.time.monotonic", new=fast_monotonic - ): - self.handler_method(event) - self.hass.data[influxdb.DOMAIN].block_till_done() + def fast_monotonic(): + """Monotonic time that ticks fast enough to cause a timeout.""" + nonlocal monotonic_time + monotonic_time += 60 + return monotonic_time - assert mock_client.return_value.write_points.call_count == 0 + with patch("homeassistant.components.influxdb.time.monotonic", new=fast_monotonic): + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() - mock_client.return_value.write_points.reset_mock() + assert get_write_api(mock_client).call_count == 0