diff --git a/homeassistant/components/influxdb/__init__.py b/homeassistant/components/influxdb/__init__.py index 45d3a4f5a25..057d0657685 100644 --- a/homeassistant/components/influxdb/__init__.py +++ b/homeassistant/components/influxdb/__init__.py @@ -52,6 +52,7 @@ from .const import ( CONF_DEFAULT_MEASUREMENT, CONF_HOST, CONF_IGNORE_ATTRIBUTES, + CONF_MEASUREMENT_ATTR, CONF_ORG, CONF_OVERRIDE_MEASUREMENT, CONF_PASSWORD, @@ -68,6 +69,7 @@ from .const import ( CONNECTION_ERROR, DEFAULT_API_VERSION, DEFAULT_HOST_V2, + DEFAULT_MEASUREMENT_ATTR, DEFAULT_SSL_V2, DOMAIN, EVENT_NEW_STATE, @@ -154,6 +156,9 @@ _INFLUX_BASE_SCHEMA = INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.extend( { vol.Optional(CONF_RETRY_COUNT, default=0): cv.positive_int, vol.Optional(CONF_DEFAULT_MEASUREMENT): cv.string, + vol.Optional(CONF_MEASUREMENT_ATTR, default=DEFAULT_MEASUREMENT_ATTR): vol.In( + ["unit_of_measurement", "domain__device_class", "entity_id"] + ), vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string, vol.Optional(CONF_TAGS, default={}): vol.Schema({cv.string: cv.string}), vol.Optional(CONF_TAGS_ATTRIBUTES, default=[]): vol.All( @@ -192,6 +197,7 @@ def _generate_event_to_json(conf: Dict) -> Callable[[Dict], str]: tags = conf.get(CONF_TAGS) tags_attributes = conf.get(CONF_TAGS_ATTRIBUTES) default_measurement = conf.get(CONF_DEFAULT_MEASUREMENT) + measurement_attr = conf.get(CONF_MEASUREMENT_ATTR) override_measurement = conf.get(CONF_OVERRIDE_MEASUREMENT) global_ignore_attributes = set(conf[CONF_IGNORE_ATTRIBUTES]) component_config = EntityValues( @@ -223,20 +229,32 @@ def _generate_event_to_json(conf: Dict) -> Callable[[Dict], str]: _include_state = True include_uom = True + include_dc = True entity_config = component_config.get(state.entity_id) measurement = entity_config.get(CONF_OVERRIDE_MEASUREMENT) if measurement in (None, ""): if override_measurement: measurement = override_measurement else: - measurement = state.attributes.get(CONF_UNIT_OF_MEASUREMENT) + if measurement_attr == "entity_id": + measurement = state.entity_id + elif measurement_attr == "domain__device_class": + device_class = state.attributes.get("device_class") + if device_class is None: + # This entity doesn't have a device_class set, use only domain + measurement = state.domain + else: + measurement = f"{state.domain}__{device_class}" + include_dc = False + else: + measurement = state.attributes.get(measurement_attr) if measurement in (None, ""): if default_measurement: measurement = default_measurement else: measurement = state.entity_id else: - include_uom = False + include_uom = measurement_attr != "unit_of_measurement" json = { INFLUX_CONF_MEASUREMENT: measurement, @@ -258,8 +276,10 @@ def _generate_event_to_json(conf: Dict) -> Callable[[Dict], str]: if key in tags_attributes: json[INFLUX_CONF_TAGS][key] = value elif ( - key != CONF_UNIT_OF_MEASUREMENT or include_uom - ) and key not in ignore_attributes: + (key != CONF_UNIT_OF_MEASUREMENT or include_uom) + and (key != "device_class" or include_dc) + and key not in ignore_attributes + ): # If the key is already in fields if key in json[INFLUX_CONF_FIELDS]: key = f"{key}_" diff --git a/homeassistant/components/influxdb/const.py b/homeassistant/components/influxdb/const.py index 029e4d482e8..1a827c1b63c 100644 --- a/homeassistant/components/influxdb/const.py +++ b/homeassistant/components/influxdb/const.py @@ -22,6 +22,7 @@ CONF_BUCKET = "bucket" CONF_ORG = "organization" CONF_TAGS = "tags" CONF_DEFAULT_MEASUREMENT = "default_measurement" +CONF_MEASUREMENT_ATTR = "measurement_attr" CONF_OVERRIDE_MEASUREMENT = "override_measurement" CONF_TAGS_ATTRIBUTES = "tags_attributes" CONF_COMPONENT_CONFIG = "component_config" @@ -56,6 +57,7 @@ DEFAULT_FIELD = "value" DEFAULT_RANGE_START = "-15m" DEFAULT_RANGE_STOP = "now()" DEFAULT_FUNCTION_FLUX = "|> limit(n: 1)" +DEFAULT_MEASUREMENT_ATTR = "unit_of_measurement" INFLUX_CONF_MEASUREMENT = "measurement" INFLUX_CONF_TAGS = "tags" diff --git a/tests/components/influxdb/test_init.py b/tests/components/influxdb/test_init.py index edb85e7b98d..06ec3725195 100644 --- a/tests/components/influxdb/test_init.py +++ b/tests/components/influxdb/test_init.py @@ -1066,6 +1066,79 @@ async def test_event_listener_component_override_measurement( write_api.reset_mock() +@pytest.mark.parametrize( + "mock_client, config_ext, get_write_api, get_mock_call", + [ + ( + influxdb.DEFAULT_API_VERSION, + BASE_V1_CONFIG, + _get_write_api_mock_v1, + influxdb.DEFAULT_API_VERSION, + ), + ( + influxdb.API_VERSION_2, + BASE_V2_CONFIG, + _get_write_api_mock_v2, + influxdb.API_VERSION_2, + ), + ], + indirect=["mock_client", "get_mock_call"], +) +async def test_event_listener_component_measurement_attr( + hass, mock_client, config_ext, get_write_api, get_mock_call +): + """Test the event listener with a different measurement_attr.""" + config = { + "measurement_attr": "domain__device_class", + "component_config": { + "sensor.fake_humidity": {"override_measurement": "humidity"} + }, + "component_config_glob": { + "binary_sensor.*motion": {"override_measurement": "motion"} + }, + "component_config_domain": {"climate": {"override_measurement": "hvac"}}, + } + config.update(config_ext) + handler_method = await _setup(hass, mock_client, config, get_write_api) + + test_components = [ + { + "domain": "sensor", + "id": "fake_temperature", + "attrs": {"device_class": "humidity"}, + "res": "sensor__humidity", + }, + {"domain": "sensor", "id": "fake_humidity", "attrs": {}, "res": "humidity"}, + {"domain": "binary_sensor", "id": "fake_motion", "attrs": {}, "res": "motion"}, + {"domain": "climate", "id": "fake_thermostat", "attrs": {}, "res": "hvac"}, + {"domain": "other", "id": "just_fake", "attrs": {}, "res": "other"}, + ] + for comp in test_components: + state = MagicMock( + state=1, + domain=comp["domain"], + entity_id=f"{comp['domain']}.{comp['id']}", + object_id=comp["id"], + attributes=comp["attrs"], + ) + event = MagicMock(data={"new_state": state}, time_fired=12345) + body = [ + { + "measurement": comp["res"], + "tags": {"domain": comp["domain"], "entity_id": comp["id"]}, + "time": 12345, + "fields": {"value": 1}, + } + ] + handler_method(event) + hass.data[influxdb.DOMAIN].block_till_done() + + write_api = get_write_api(mock_client) + assert write_api.call_count == 1 + assert write_api.call_args == get_mock_call(body) + write_api.reset_mock() + + @pytest.mark.parametrize( "mock_client, config_ext, get_write_api, get_mock_call", [