Support publishing MQTT messages with raw bytes payloads (#61090)
* correctly publish mqtt ouput * Additional tests * Add template test with binary output * render_outgoing_payload with command templates * use MqttCommandTemplate helper class * add tests command_template * Additional tests * support pass-through for MqttComandTemplate * fix bugs * unify workform always initiate with hass * clean up * remove not needed lines * comment not adding valuepull/61875/head
parent
a1abcad0ca
commit
d5defa8995
|
@ -1,6 +1,7 @@
|
|||
"""Support for MQTT message handling."""
|
||||
from __future__ import annotations
|
||||
|
||||
from ast import literal_eval
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
import datetime as dt
|
||||
|
@ -250,6 +251,55 @@ MQTT_PUBLISH_SCHEMA = vol.All(
|
|||
SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None
|
||||
|
||||
|
||||
class MqttCommandTemplate:
|
||||
"""Class for rendering MQTT payload with command templates."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command_template: template.Template | None,
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Instantiate a command template."""
|
||||
self._attr_command_template = command_template
|
||||
if command_template is None:
|
||||
return
|
||||
|
||||
command_template.hass = hass
|
||||
|
||||
@callback
|
||||
def async_render(
|
||||
self,
|
||||
value: PublishPayloadType = None,
|
||||
variables: template.TemplateVarsType = None,
|
||||
) -> PublishPayloadType:
|
||||
"""Render or convert the command template with given value or variables."""
|
||||
|
||||
def _convert_outgoing_payload(
|
||||
payload: PublishPayloadType,
|
||||
) -> PublishPayloadType:
|
||||
"""Ensure correct raw MQTT payload is passed as bytes for publishing."""
|
||||
if isinstance(payload, str):
|
||||
try:
|
||||
native_object = literal_eval(payload)
|
||||
if isinstance(native_object, bytes):
|
||||
return native_object
|
||||
|
||||
except (ValueError, TypeError, SyntaxError, MemoryError):
|
||||
pass
|
||||
|
||||
return payload
|
||||
|
||||
if self._attr_command_template is None:
|
||||
return value
|
||||
|
||||
values = {"value": value}
|
||||
if variables is not None:
|
||||
values.update(variables)
|
||||
return _convert_outgoing_payload(
|
||||
self._attr_command_template.async_render(values, parse_result=False)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MqttServiceInfo(BaseServiceInfo):
|
||||
"""Prepared info from mqtt entries."""
|
||||
|
@ -295,7 +345,9 @@ async def async_publish(
|
|||
hass: HomeAssistant, topic: Any, payload, qos=0, retain=False
|
||||
) -> None:
|
||||
"""Publish message to an MQTT topic."""
|
||||
await hass.data[DATA_MQTT].async_publish(topic, str(payload), qos, retain)
|
||||
await hass.data[DATA_MQTT].async_publish(
|
||||
topic, str(payload) if not isinstance(payload, bytes) else payload, qos, retain
|
||||
)
|
||||
|
||||
|
||||
AsyncDeprecatedMessageCallbackType = Callable[
|
||||
|
@ -523,9 +575,9 @@ async def async_setup_entry(hass, entry):
|
|||
|
||||
if payload_template is not None:
|
||||
try:
|
||||
payload = template.Template(payload_template, hass).async_render(
|
||||
parse_result=False
|
||||
)
|
||||
payload = MqttCommandTemplate(
|
||||
template.Template(payload_template), hass
|
||||
).async_render()
|
||||
except (template.jinja2.TemplateError, TemplateError) as exc:
|
||||
_LOGGER.error(
|
||||
"Unable to publish to %s: rendering payload template of "
|
||||
|
|
|
@ -34,7 +34,7 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.helpers.reload import async_setup_reload_service
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import PLATFORMS, subscription
|
||||
from . import PLATFORMS, MqttCommandTemplate, subscription
|
||||
from .. import mqtt
|
||||
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
|
||||
from .debug_info import log_messages
|
||||
|
@ -150,8 +150,9 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
|
|||
value_template = self._config.get(CONF_VALUE_TEMPLATE)
|
||||
if value_template is not None:
|
||||
value_template.hass = self.hass
|
||||
command_template = self._config[CONF_COMMAND_TEMPLATE]
|
||||
command_template.hass = self.hass
|
||||
self._command_template = MqttCommandTemplate(
|
||||
self._config[CONF_COMMAND_TEMPLATE], self.hass
|
||||
).async_render
|
||||
|
||||
async def _subscribe_topics(self):
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
@ -306,9 +307,8 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
|
|||
|
||||
async def _publish(self, code, action):
|
||||
"""Publish via mqtt."""
|
||||
command_template = self._config[CONF_COMMAND_TEMPLATE]
|
||||
values = {"action": action, "code": code}
|
||||
payload = command_template.async_render(**values, parse_result=False)
|
||||
variables = {"action": action, "code": code}
|
||||
payload = self._command_template(None, variables=variables)
|
||||
await mqtt.async_publish(
|
||||
self.hass,
|
||||
self._config[CONF_COMMAND_TOPIC],
|
||||
|
|
|
@ -52,7 +52,7 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.helpers.reload import async_setup_reload_service
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import MQTT_BASE_PLATFORM_SCHEMA, PLATFORMS, subscription
|
||||
from . import MQTT_BASE_PLATFORM_SCHEMA, PLATFORMS, MqttCommandTemplate, subscription
|
||||
from .. import mqtt
|
||||
from .const import CONF_QOS, CONF_RETAIN, DOMAIN
|
||||
from .debug_info import log_messages
|
||||
|
@ -377,11 +377,10 @@ class MqttClimate(MqttEntity, ClimateEntity):
|
|||
|
||||
command_templates = {}
|
||||
for key in COMMAND_TEMPLATE_KEYS:
|
||||
command_templates[key] = lambda value: value
|
||||
for key in COMMAND_TEMPLATE_KEYS & config.keys():
|
||||
tpl = config[key]
|
||||
command_templates[key] = tpl.async_render_with_possible_json_value
|
||||
tpl.hass = self.hass
|
||||
command_templates[key] = MqttCommandTemplate(
|
||||
config.get(key), self.hass
|
||||
).async_render
|
||||
|
||||
self._command_templates = command_templates
|
||||
|
||||
async def _subscribe_topics(self): # noqa: C901
|
||||
|
|
|
@ -36,7 +36,7 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.helpers.reload import async_setup_reload_service
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import PLATFORMS, subscription
|
||||
from . import PLATFORMS, MqttCommandTemplate, subscription
|
||||
from .. import mqtt
|
||||
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
|
||||
from .debug_info import log_messages
|
||||
|
@ -288,17 +288,17 @@ class MqttCover(MqttEntity, CoverEntity):
|
|||
if value_template is not None:
|
||||
value_template.hass = self.hass
|
||||
|
||||
set_position_template = self._config.get(CONF_SET_POSITION_TEMPLATE)
|
||||
if set_position_template is not None:
|
||||
set_position_template.hass = self.hass
|
||||
self._set_position_template = MqttCommandTemplate(
|
||||
self._config.get(CONF_SET_POSITION_TEMPLATE), self.hass
|
||||
).async_render
|
||||
|
||||
get_position_template = self._config.get(CONF_GET_POSITION_TEMPLATE)
|
||||
if get_position_template is not None:
|
||||
get_position_template.hass = self.hass
|
||||
|
||||
set_tilt_template = self._config.get(CONF_TILT_COMMAND_TEMPLATE)
|
||||
if set_tilt_template is not None:
|
||||
set_tilt_template.hass = self.hass
|
||||
self._set_tilt_template = MqttCommandTemplate(
|
||||
self._config.get(CONF_TILT_COMMAND_TEMPLATE), self.hass
|
||||
).async_render
|
||||
|
||||
tilt_status_template = self._config.get(CONF_TILT_STATUS_TEMPLATE)
|
||||
if tilt_status_template is not None:
|
||||
|
@ -611,21 +611,19 @@ class MqttCover(MqttEntity, CoverEntity):
|
|||
|
||||
async def async_set_cover_tilt_position(self, **kwargs):
|
||||
"""Move the cover tilt to a specific position."""
|
||||
template = self._config.get(CONF_TILT_COMMAND_TEMPLATE)
|
||||
tilt = kwargs[ATTR_TILT_POSITION]
|
||||
percentage_tilt = tilt
|
||||
tilt = self.find_in_range_from_percent(tilt)
|
||||
# Handover the tilt after calculated from percent would make it more consistent with receiving templates
|
||||
if template is not None:
|
||||
variables = {
|
||||
"tilt_position": percentage_tilt,
|
||||
"entity_id": self.entity_id,
|
||||
"position_open": self._config[CONF_POSITION_OPEN],
|
||||
"position_closed": self._config[CONF_POSITION_CLOSED],
|
||||
"tilt_min": self._config[CONF_TILT_MIN],
|
||||
"tilt_max": self._config[CONF_TILT_MAX],
|
||||
}
|
||||
tilt = template.async_render(parse_result=False, variables=variables)
|
||||
variables = {
|
||||
"tilt_position": percentage_tilt,
|
||||
"entity_id": self.entity_id,
|
||||
"position_open": self._config.get(CONF_POSITION_OPEN),
|
||||
"position_closed": self._config.get(CONF_POSITION_CLOSED),
|
||||
"tilt_min": self._config.get(CONF_TILT_MIN),
|
||||
"tilt_max": self._config.get(CONF_TILT_MAX),
|
||||
}
|
||||
tilt = self._set_tilt_template(tilt, variables=variables)
|
||||
|
||||
await mqtt.async_publish(
|
||||
self.hass,
|
||||
|
@ -641,20 +639,18 @@ class MqttCover(MqttEntity, CoverEntity):
|
|||
|
||||
async def async_set_cover_position(self, **kwargs):
|
||||
"""Move the cover to a specific position."""
|
||||
template = self._config.get(CONF_SET_POSITION_TEMPLATE)
|
||||
position = kwargs[ATTR_POSITION]
|
||||
percentage_position = position
|
||||
position = self.find_in_range_from_percent(position, COVER_PAYLOAD)
|
||||
if template is not None:
|
||||
variables = {
|
||||
"position": percentage_position,
|
||||
"entity_id": self.entity_id,
|
||||
"position_open": self._config[CONF_POSITION_OPEN],
|
||||
"position_closed": self._config[CONF_POSITION_CLOSED],
|
||||
"tilt_min": self._config[CONF_TILT_MIN],
|
||||
"tilt_max": self._config[CONF_TILT_MAX],
|
||||
}
|
||||
position = template.async_render(parse_result=False, variables=variables)
|
||||
variables = {
|
||||
"position": percentage_position,
|
||||
"entity_id": self.entity_id,
|
||||
"position_open": self._config[CONF_POSITION_OPEN],
|
||||
"position_closed": self._config[CONF_POSITION_CLOSED],
|
||||
"tilt_min": self._config[CONF_TILT_MIN],
|
||||
"tilt_max": self._config[CONF_TILT_MAX],
|
||||
}
|
||||
position = self._set_position_template(position, variables=variables)
|
||||
|
||||
await mqtt.async_publish(
|
||||
self.hass,
|
||||
|
|
|
@ -36,7 +36,7 @@ from homeassistant.util.percentage import (
|
|||
ranged_value_to_percentage,
|
||||
)
|
||||
|
||||
from . import PLATFORMS, subscription
|
||||
from . import PLATFORMS, MqttCommandTemplate, subscription
|
||||
from .. import mqtt
|
||||
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
|
||||
from .debug_info import log_messages
|
||||
|
@ -332,13 +332,17 @@ class MqttFan(MqttEntity, FanEntity):
|
|||
if self._feature_preset_mode:
|
||||
self._supported_features |= SUPPORT_PRESET_MODE
|
||||
|
||||
for tpl_dict in (self._command_templates, self._value_templates):
|
||||
for key, tpl in tpl_dict.items():
|
||||
if tpl is None:
|
||||
tpl_dict[key] = lambda value: value
|
||||
else:
|
||||
tpl.hass = self.hass
|
||||
tpl_dict[key] = tpl.async_render_with_possible_json_value
|
||||
for key, tpl in self._command_templates.items():
|
||||
self._command_templates[key] = MqttCommandTemplate(
|
||||
tpl, self.hass
|
||||
).async_render
|
||||
|
||||
for key, tpl in self._value_templates.items():
|
||||
if tpl is None:
|
||||
self._value_templates[key] = lambda value: value
|
||||
else:
|
||||
tpl.hass = self.hass
|
||||
self._value_templates[key] = tpl.async_render_with_possible_json_value
|
||||
|
||||
async def _subscribe_topics(self):
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
|
|
@ -27,7 +27,7 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.helpers.reload import async_setup_reload_service
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import PLATFORMS, subscription
|
||||
from . import PLATFORMS, MqttCommandTemplate, subscription
|
||||
from .. import mqtt
|
||||
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
|
||||
from .debug_info import log_messages
|
||||
|
@ -237,13 +237,17 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
|
|||
)
|
||||
self._optimistic_mode = optimistic or self._topic[CONF_MODE_STATE_TOPIC] is None
|
||||
|
||||
for tpl_dict in (self._command_templates, self._value_templates):
|
||||
for key, tpl in tpl_dict.items():
|
||||
if tpl is None:
|
||||
tpl_dict[key] = lambda value: value
|
||||
else:
|
||||
tpl.hass = self.hass
|
||||
tpl_dict[key] = tpl.async_render_with_possible_json_value
|
||||
for key, tpl in self._command_templates.items():
|
||||
self._command_templates[key] = MqttCommandTemplate(
|
||||
tpl, self.hass
|
||||
).async_render
|
||||
|
||||
for key, tpl in self._value_templates.items():
|
||||
if tpl is None:
|
||||
self._value_templates[key] = lambda value: value
|
||||
else:
|
||||
tpl.hass = self.hass
|
||||
self._value_templates[key] = tpl.async_render_with_possible_json_value
|
||||
|
||||
async def _subscribe_topics(self):
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
|
|
@ -25,7 +25,7 @@ from homeassistant.helpers.reload import async_setup_reload_service
|
|||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import PLATFORMS, subscription
|
||||
from . import PLATFORMS, MqttCommandTemplate, subscription
|
||||
from .. import mqtt
|
||||
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
|
||||
from .debug_info import log_messages
|
||||
|
@ -138,15 +138,20 @@ class MqttNumber(MqttEntity, NumberEntity, RestoreEntity):
|
|||
self._optimistic = config[CONF_OPTIMISTIC]
|
||||
|
||||
self._templates = {
|
||||
CONF_COMMAND_TEMPLATE: config.get(CONF_COMMAND_TEMPLATE),
|
||||
CONF_COMMAND_TEMPLATE: MqttCommandTemplate(
|
||||
config.get(CONF_COMMAND_TEMPLATE), self.hass
|
||||
).async_render,
|
||||
CONF_VALUE_TEMPLATE: config.get(CONF_VALUE_TEMPLATE),
|
||||
}
|
||||
for key, tpl in self._templates.items():
|
||||
if tpl is None:
|
||||
self._templates[key] = lambda value: value
|
||||
else:
|
||||
tpl.hass = self.hass
|
||||
self._templates[key] = tpl.async_render_with_possible_json_value
|
||||
|
||||
value_template = self._templates[CONF_VALUE_TEMPLATE]
|
||||
if value_template is None:
|
||||
self._templates[CONF_VALUE_TEMPLATE] = lambda value: value
|
||||
else:
|
||||
value_template.hass = self.hass
|
||||
self._templates[
|
||||
CONF_VALUE_TEMPLATE
|
||||
] = value_template.async_render_with_possible_json_value
|
||||
|
||||
async def _subscribe_topics(self):
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
|
|
@ -13,7 +13,7 @@ from homeassistant.helpers.reload import async_setup_reload_service
|
|||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import PLATFORMS, subscription
|
||||
from . import PLATFORMS, MqttCommandTemplate, subscription
|
||||
from .. import mqtt
|
||||
from .const import CONF_COMMAND_TOPIC, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC, DOMAIN
|
||||
from .debug_info import log_messages
|
||||
|
@ -102,15 +102,20 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
|
|||
self._attr_options = config[CONF_OPTIONS]
|
||||
|
||||
self._templates = {
|
||||
CONF_COMMAND_TEMPLATE: config.get(CONF_COMMAND_TEMPLATE),
|
||||
CONF_COMMAND_TEMPLATE: MqttCommandTemplate(
|
||||
config.get(CONF_COMMAND_TEMPLATE), self.hass
|
||||
).async_render,
|
||||
CONF_VALUE_TEMPLATE: config.get(CONF_VALUE_TEMPLATE),
|
||||
}
|
||||
for key, tpl in self._templates.items():
|
||||
if tpl is None:
|
||||
self._templates[key] = lambda value: value
|
||||
else:
|
||||
tpl.hass = self.hass
|
||||
self._templates[key] = tpl.async_render_with_possible_json_value
|
||||
|
||||
value_template = self._templates[CONF_VALUE_TEMPLATE]
|
||||
if value_template is None:
|
||||
self._templates[CONF_VALUE_TEMPLATE] = lambda value: value
|
||||
else:
|
||||
value_template.hass = self.hass
|
||||
self._templates[
|
||||
CONF_VALUE_TEMPLATE
|
||||
] = value_template.async_render_with_possible_json_value
|
||||
|
||||
async def _subscribe_topics(self):
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
|
|
@ -18,7 +18,7 @@ from homeassistant.const import (
|
|||
)
|
||||
from homeassistant.core import CoreState, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.helpers import device_registry as dr, template
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
|
@ -91,7 +91,7 @@ async def test_mqtt_disconnects_on_home_assistant_stop(hass, mqtt_mock):
|
|||
assert mqtt_mock.async_disconnect.called
|
||||
|
||||
|
||||
async def test_publish_(hass, mqtt_mock):
|
||||
async def test_publish(hass, mqtt_mock):
|
||||
"""Test the publish function."""
|
||||
await mqtt.async_publish(hass, "test-topic", "test-payload")
|
||||
await hass.async_block_till_done()
|
||||
|
@ -137,6 +137,57 @@ async def test_publish_(hass, mqtt_mock):
|
|||
)
|
||||
mqtt_mock.reset_mock()
|
||||
|
||||
# test binary pass-through
|
||||
mqtt.publish(
|
||||
hass,
|
||||
"test-topic3",
|
||||
b"\xde\xad\xbe\xef",
|
||||
0,
|
||||
False,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert mqtt_mock.async_publish.called
|
||||
assert mqtt_mock.async_publish.call_args[0] == (
|
||||
"test-topic3",
|
||||
b"\xde\xad\xbe\xef",
|
||||
0,
|
||||
False,
|
||||
)
|
||||
mqtt_mock.reset_mock()
|
||||
|
||||
|
||||
async def test_convert_outgoing_payload(hass):
|
||||
"""Test the converting of outgoing MQTT payloads without template."""
|
||||
command_template = mqtt.MqttCommandTemplate(None, hass)
|
||||
assert command_template.async_render(b"\xde\xad\xbe\xef") == b"\xde\xad\xbe\xef"
|
||||
|
||||
assert (
|
||||
command_template.async_render("b'\\xde\\xad\\xbe\\xef'")
|
||||
== "b'\\xde\\xad\\xbe\\xef'"
|
||||
)
|
||||
|
||||
assert command_template.async_render(1234) == 1234
|
||||
|
||||
assert command_template.async_render(1234.56) == 1234.56
|
||||
|
||||
assert command_template.async_render(None) is None
|
||||
|
||||
|
||||
async def test_command_template_value(hass):
|
||||
"""Test the rendering of MQTT command template."""
|
||||
|
||||
variables = {"id": 1234, "some_var": "beer"}
|
||||
|
||||
# test rendering value
|
||||
tpl = template.Template("{{ value + 1 }}", hass)
|
||||
cmd_tpl = mqtt.MqttCommandTemplate(tpl, hass)
|
||||
assert cmd_tpl.async_render(4321) == "4322"
|
||||
|
||||
# test variables at rendering
|
||||
tpl = template.Template("{{ some_var }}", hass)
|
||||
cmd_tpl = mqtt.MqttCommandTemplate(tpl, hass)
|
||||
assert cmd_tpl.async_render(None, variables=variables) == "beer"
|
||||
|
||||
|
||||
async def test_service_call_without_topic_does_not_publish(hass, mqtt_mock):
|
||||
"""Test the service call if topic is missing."""
|
||||
|
@ -260,6 +311,20 @@ async def test_service_call_with_template_payload_renders_template(hass, mqtt_mo
|
|||
)
|
||||
assert mqtt_mock.async_publish.called
|
||||
assert mqtt_mock.async_publish.call_args[0][1] == "8"
|
||||
mqtt_mock.reset_mock()
|
||||
|
||||
await hass.services.async_call(
|
||||
mqtt.DOMAIN,
|
||||
mqtt.SERVICE_PUBLISH,
|
||||
{
|
||||
mqtt.ATTR_TOPIC: "test/topic",
|
||||
mqtt.ATTR_PAYLOAD_TEMPLATE: "{{ (4+4) | pack('B') }}",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
assert mqtt_mock.async_publish.called
|
||||
assert mqtt_mock.async_publish.call_args[0][1] == b"\x08"
|
||||
mqtt_mock.reset_mock()
|
||||
|
||||
|
||||
async def test_service_call_with_bad_template(hass, mqtt_mock):
|
||||
|
|
Loading…
Reference in New Issue