Support publishing MQTT messages with raw bytes payloads (#61090)

* correctly publish mqtt ouput

* Additional tests

* Add template test with binary output

* render_outgoing_payload with command templates

* use MqttCommandTemplate helper class

* add tests command_template

* Additional tests

* support pass-through for MqttComandTemplate

* fix bugs

* unify workform always initiate with hass

* clean up

* remove not needed lines

* comment not adding value
pull/61875/head
Jan Bouwhuis 2021-12-15 11:28:43 +01:00 committed by GitHub
parent a1abcad0ca
commit d5defa8995
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 209 additions and 79 deletions

View File

@ -1,6 +1,7 @@
"""Support for MQTT message handling."""
from __future__ import annotations
from ast import literal_eval
import asyncio
from dataclasses import dataclass
import datetime as dt
@ -250,6 +251,55 @@ MQTT_PUBLISH_SCHEMA = vol.All(
SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None
class MqttCommandTemplate:
"""Class for rendering MQTT payload with command templates."""
def __init__(
self,
command_template: template.Template | None,
hass: HomeAssistant,
) -> None:
"""Instantiate a command template."""
self._attr_command_template = command_template
if command_template is None:
return
command_template.hass = hass
@callback
def async_render(
self,
value: PublishPayloadType = None,
variables: template.TemplateVarsType = None,
) -> PublishPayloadType:
"""Render or convert the command template with given value or variables."""
def _convert_outgoing_payload(
payload: PublishPayloadType,
) -> PublishPayloadType:
"""Ensure correct raw MQTT payload is passed as bytes for publishing."""
if isinstance(payload, str):
try:
native_object = literal_eval(payload)
if isinstance(native_object, bytes):
return native_object
except (ValueError, TypeError, SyntaxError, MemoryError):
pass
return payload
if self._attr_command_template is None:
return value
values = {"value": value}
if variables is not None:
values.update(variables)
return _convert_outgoing_payload(
self._attr_command_template.async_render(values, parse_result=False)
)
@dataclass
class MqttServiceInfo(BaseServiceInfo):
"""Prepared info from mqtt entries."""
@ -295,7 +345,9 @@ async def async_publish(
hass: HomeAssistant, topic: Any, payload, qos=0, retain=False
) -> None:
"""Publish message to an MQTT topic."""
await hass.data[DATA_MQTT].async_publish(topic, str(payload), qos, retain)
await hass.data[DATA_MQTT].async_publish(
topic, str(payload) if not isinstance(payload, bytes) else payload, qos, retain
)
AsyncDeprecatedMessageCallbackType = Callable[
@ -523,9 +575,9 @@ async def async_setup_entry(hass, entry):
if payload_template is not None:
try:
payload = template.Template(payload_template, hass).async_render(
parse_result=False
)
payload = MqttCommandTemplate(
template.Template(payload_template), hass
).async_render()
except (template.jinja2.TemplateError, TemplateError) as exc:
_LOGGER.error(
"Unable to publish to %s: rendering payload template of "

View File

@ -34,7 +34,7 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.typing import ConfigType
from . import PLATFORMS, subscription
from . import PLATFORMS, MqttCommandTemplate, subscription
from .. import mqtt
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
from .debug_info import log_messages
@ -150,8 +150,9 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
value_template = self._config.get(CONF_VALUE_TEMPLATE)
if value_template is not None:
value_template.hass = self.hass
command_template = self._config[CONF_COMMAND_TEMPLATE]
command_template.hass = self.hass
self._command_template = MqttCommandTemplate(
self._config[CONF_COMMAND_TEMPLATE], self.hass
).async_render
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
@ -306,9 +307,8 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
async def _publish(self, code, action):
"""Publish via mqtt."""
command_template = self._config[CONF_COMMAND_TEMPLATE]
values = {"action": action, "code": code}
payload = command_template.async_render(**values, parse_result=False)
variables = {"action": action, "code": code}
payload = self._command_template(None, variables=variables)
await mqtt.async_publish(
self.hass,
self._config[CONF_COMMAND_TOPIC],

View File

@ -52,7 +52,7 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.typing import ConfigType
from . import MQTT_BASE_PLATFORM_SCHEMA, PLATFORMS, subscription
from . import MQTT_BASE_PLATFORM_SCHEMA, PLATFORMS, MqttCommandTemplate, subscription
from .. import mqtt
from .const import CONF_QOS, CONF_RETAIN, DOMAIN
from .debug_info import log_messages
@ -377,11 +377,10 @@ class MqttClimate(MqttEntity, ClimateEntity):
command_templates = {}
for key in COMMAND_TEMPLATE_KEYS:
command_templates[key] = lambda value: value
for key in COMMAND_TEMPLATE_KEYS & config.keys():
tpl = config[key]
command_templates[key] = tpl.async_render_with_possible_json_value
tpl.hass = self.hass
command_templates[key] = MqttCommandTemplate(
config.get(key), self.hass
).async_render
self._command_templates = command_templates
async def _subscribe_topics(self): # noqa: C901

View File

@ -36,7 +36,7 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.typing import ConfigType
from . import PLATFORMS, subscription
from . import PLATFORMS, MqttCommandTemplate, subscription
from .. import mqtt
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
from .debug_info import log_messages
@ -288,17 +288,17 @@ class MqttCover(MqttEntity, CoverEntity):
if value_template is not None:
value_template.hass = self.hass
set_position_template = self._config.get(CONF_SET_POSITION_TEMPLATE)
if set_position_template is not None:
set_position_template.hass = self.hass
self._set_position_template = MqttCommandTemplate(
self._config.get(CONF_SET_POSITION_TEMPLATE), self.hass
).async_render
get_position_template = self._config.get(CONF_GET_POSITION_TEMPLATE)
if get_position_template is not None:
get_position_template.hass = self.hass
set_tilt_template = self._config.get(CONF_TILT_COMMAND_TEMPLATE)
if set_tilt_template is not None:
set_tilt_template.hass = self.hass
self._set_tilt_template = MqttCommandTemplate(
self._config.get(CONF_TILT_COMMAND_TEMPLATE), self.hass
).async_render
tilt_status_template = self._config.get(CONF_TILT_STATUS_TEMPLATE)
if tilt_status_template is not None:
@ -611,21 +611,19 @@ class MqttCover(MqttEntity, CoverEntity):
async def async_set_cover_tilt_position(self, **kwargs):
"""Move the cover tilt to a specific position."""
template = self._config.get(CONF_TILT_COMMAND_TEMPLATE)
tilt = kwargs[ATTR_TILT_POSITION]
percentage_tilt = tilt
tilt = self.find_in_range_from_percent(tilt)
# Handover the tilt after calculated from percent would make it more consistent with receiving templates
if template is not None:
variables = {
"tilt_position": percentage_tilt,
"entity_id": self.entity_id,
"position_open": self._config[CONF_POSITION_OPEN],
"position_closed": self._config[CONF_POSITION_CLOSED],
"tilt_min": self._config[CONF_TILT_MIN],
"tilt_max": self._config[CONF_TILT_MAX],
}
tilt = template.async_render(parse_result=False, variables=variables)
variables = {
"tilt_position": percentage_tilt,
"entity_id": self.entity_id,
"position_open": self._config.get(CONF_POSITION_OPEN),
"position_closed": self._config.get(CONF_POSITION_CLOSED),
"tilt_min": self._config.get(CONF_TILT_MIN),
"tilt_max": self._config.get(CONF_TILT_MAX),
}
tilt = self._set_tilt_template(tilt, variables=variables)
await mqtt.async_publish(
self.hass,
@ -641,20 +639,18 @@ class MqttCover(MqttEntity, CoverEntity):
async def async_set_cover_position(self, **kwargs):
"""Move the cover to a specific position."""
template = self._config.get(CONF_SET_POSITION_TEMPLATE)
position = kwargs[ATTR_POSITION]
percentage_position = position
position = self.find_in_range_from_percent(position, COVER_PAYLOAD)
if template is not None:
variables = {
"position": percentage_position,
"entity_id": self.entity_id,
"position_open": self._config[CONF_POSITION_OPEN],
"position_closed": self._config[CONF_POSITION_CLOSED],
"tilt_min": self._config[CONF_TILT_MIN],
"tilt_max": self._config[CONF_TILT_MAX],
}
position = template.async_render(parse_result=False, variables=variables)
variables = {
"position": percentage_position,
"entity_id": self.entity_id,
"position_open": self._config[CONF_POSITION_OPEN],
"position_closed": self._config[CONF_POSITION_CLOSED],
"tilt_min": self._config[CONF_TILT_MIN],
"tilt_max": self._config[CONF_TILT_MAX],
}
position = self._set_position_template(position, variables=variables)
await mqtt.async_publish(
self.hass,

View File

@ -36,7 +36,7 @@ from homeassistant.util.percentage import (
ranged_value_to_percentage,
)
from . import PLATFORMS, subscription
from . import PLATFORMS, MqttCommandTemplate, subscription
from .. import mqtt
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
from .debug_info import log_messages
@ -332,13 +332,17 @@ class MqttFan(MqttEntity, FanEntity):
if self._feature_preset_mode:
self._supported_features |= SUPPORT_PRESET_MODE
for tpl_dict in (self._command_templates, self._value_templates):
for key, tpl in tpl_dict.items():
if tpl is None:
tpl_dict[key] = lambda value: value
else:
tpl.hass = self.hass
tpl_dict[key] = tpl.async_render_with_possible_json_value
for key, tpl in self._command_templates.items():
self._command_templates[key] = MqttCommandTemplate(
tpl, self.hass
).async_render
for key, tpl in self._value_templates.items():
if tpl is None:
self._value_templates[key] = lambda value: value
else:
tpl.hass = self.hass
self._value_templates[key] = tpl.async_render_with_possible_json_value
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""

View File

@ -27,7 +27,7 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.typing import ConfigType
from . import PLATFORMS, subscription
from . import PLATFORMS, MqttCommandTemplate, subscription
from .. import mqtt
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
from .debug_info import log_messages
@ -237,13 +237,17 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
)
self._optimistic_mode = optimistic or self._topic[CONF_MODE_STATE_TOPIC] is None
for tpl_dict in (self._command_templates, self._value_templates):
for key, tpl in tpl_dict.items():
if tpl is None:
tpl_dict[key] = lambda value: value
else:
tpl.hass = self.hass
tpl_dict[key] = tpl.async_render_with_possible_json_value
for key, tpl in self._command_templates.items():
self._command_templates[key] = MqttCommandTemplate(
tpl, self.hass
).async_render
for key, tpl in self._value_templates.items():
if tpl is None:
self._value_templates[key] = lambda value: value
else:
tpl.hass = self.hass
self._value_templates[key] = tpl.async_render_with_possible_json_value
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""

View File

@ -25,7 +25,7 @@ from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType
from . import PLATFORMS, subscription
from . import PLATFORMS, MqttCommandTemplate, subscription
from .. import mqtt
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
from .debug_info import log_messages
@ -138,15 +138,20 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity):
self._optimistic = config[CONF_OPTIMISTIC]
self._templates = {
CONF_COMMAND_TEMPLATE: config.get(CONF_COMMAND_TEMPLATE),
CONF_COMMAND_TEMPLATE: MqttCommandTemplate(
config.get(CONF_COMMAND_TEMPLATE), self.hass
).async_render,
CONF_VALUE_TEMPLATE: config.get(CONF_VALUE_TEMPLATE),
}
for key, tpl in self._templates.items():
if tpl is None:
self._templates[key] = lambda value: value
else:
tpl.hass = self.hass
self._templates[key] = tpl.async_render_with_possible_json_value
value_template = self._templates[CONF_VALUE_TEMPLATE]
if value_template is None:
self._templates[CONF_VALUE_TEMPLATE] = lambda value: value
else:
value_template.hass = self.hass
self._templates[
CONF_VALUE_TEMPLATE
] = value_template.async_render_with_possible_json_value
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""

View File

@ -13,7 +13,7 @@ from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType
from . import PLATFORMS, subscription
from . import PLATFORMS, MqttCommandTemplate, subscription
from .. import mqtt
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
from .debug_info import log_messages
@ -102,15 +102,20 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
self._attr_options = config[CONF_OPTIONS]
self._templates = {
CONF_COMMAND_TEMPLATE: config.get(CONF_COMMAND_TEMPLATE),
CONF_COMMAND_TEMPLATE: MqttCommandTemplate(
config.get(CONF_COMMAND_TEMPLATE), self.hass
).async_render,
CONF_VALUE_TEMPLATE: config.get(CONF_VALUE_TEMPLATE),
}
for key, tpl in self._templates.items():
if tpl is None:
self._templates[key] = lambda value: value
else:
tpl.hass = self.hass
self._templates[key] = tpl.async_render_with_possible_json_value
value_template = self._templates[CONF_VALUE_TEMPLATE]
if value_template is None:
self._templates[CONF_VALUE_TEMPLATE] = lambda value: value
else:
value_template.hass = self.hass
self._templates[
CONF_VALUE_TEMPLATE
] = value_template.async_render_with_possible_json_value
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""

View File

@ -18,7 +18,7 @@ from homeassistant.const import (
)
from homeassistant.core import CoreState, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers import device_registry as dr, template
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow
@ -91,7 +91,7 @@ async def test_mqtt_disconnects_on_home_assistant_stop(hass, mqtt_mock):
assert mqtt_mock.async_disconnect.called
async def test_publish_(hass, mqtt_mock):
async def test_publish(hass, mqtt_mock):
"""Test the publish function."""
await mqtt.async_publish(hass, "test-topic", "test-payload")
await hass.async_block_till_done()
@ -137,6 +137,57 @@ async def test_publish_(hass, mqtt_mock):
)
mqtt_mock.reset_mock()
# test binary pass-through
mqtt.publish(
hass,
"test-topic3",
b"\xde\xad\xbe\xef",
0,
False,
)
await hass.async_block_till_done()
assert mqtt_mock.async_publish.called
assert mqtt_mock.async_publish.call_args[0] == (
"test-topic3",
b"\xde\xad\xbe\xef",
0,
False,
)
mqtt_mock.reset_mock()
async def test_convert_outgoing_payload(hass):
"""Test the converting of outgoing MQTT payloads without template."""
command_template = mqtt.MqttCommandTemplate(None, hass)
assert command_template.async_render(b"\xde\xad\xbe\xef") == b"\xde\xad\xbe\xef"
assert (
command_template.async_render("b'\\xde\\xad\\xbe\\xef'")
== "b'\\xde\\xad\\xbe\\xef'"
)
assert command_template.async_render(1234) == 1234
assert command_template.async_render(1234.56) == 1234.56
assert command_template.async_render(None) is None
async def test_command_template_value(hass):
"""Test the rendering of MQTT command template."""
variables = {"id": 1234, "some_var": "beer"}
# test rendering value
tpl = template.Template("{{ value + 1 }}", hass)
cmd_tpl = mqtt.MqttCommandTemplate(tpl, hass)
assert cmd_tpl.async_render(4321) == "4322"
# test variables at rendering
tpl = template.Template("{{ some_var }}", hass)
cmd_tpl = mqtt.MqttCommandTemplate(tpl, hass)
assert cmd_tpl.async_render(None, variables=variables) == "beer"
async def test_service_call_without_topic_does_not_publish(hass, mqtt_mock):
"""Test the service call if topic is missing."""
@ -260,6 +311,20 @@ async def test_service_call_with_template_payload_renders_template(hass, mqtt_mo
)
assert mqtt_mock.async_publish.called
assert mqtt_mock.async_publish.call_args[0][1] == "8"
mqtt_mock.reset_mock()
await hass.services.async_call(
mqtt.DOMAIN,
mqtt.SERVICE_PUBLISH,
{
mqtt.ATTR_TOPIC: "test/topic",
mqtt.ATTR_PAYLOAD_TEMPLATE: "{{ (4+4) | pack('B') }}",
},
blocking=True,
)
assert mqtt_mock.async_publish.called
assert mqtt_mock.async_publish.call_args[0][1] == b"\x08"
mqtt_mock.reset_mock()
async def test_service_call_with_bad_template(hass, mqtt_mock):