Abort execution of template renders that overwhelm the system (#40647)
parent
8bc62f3678
commit
e08ee282ab
|
@ -239,22 +239,32 @@ def handle_ping(hass, connection, msg):
|
|||
connection.send_message(pong_message(msg["id"]))
|
||||
|
||||
|
||||
@callback
|
||||
@decorators.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "render_template",
|
||||
vol.Required("template"): str,
|
||||
vol.Optional("entity_ids"): cv.entity_ids,
|
||||
vol.Optional("variables"): dict,
|
||||
vol.Optional("timeout"): vol.Coerce(float),
|
||||
}
|
||||
)
|
||||
def handle_render_template(hass, connection, msg):
|
||||
@decorators.async_response
|
||||
async def handle_render_template(hass, connection, msg):
|
||||
"""Handle render_template command."""
|
||||
template_str = msg["template"]
|
||||
template = Template(template_str, hass)
|
||||
variables = msg.get("variables")
|
||||
timeout = msg.get("timeout")
|
||||
info = None
|
||||
|
||||
if timeout and await template.async_render_will_timeout(timeout):
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
const.ERR_TEMPLATE_ERROR,
|
||||
f"Exceeded maximum execution time of {timeout}s",
|
||||
)
|
||||
return
|
||||
|
||||
@callback
|
||||
def _template_listener(event, updates):
|
||||
nonlocal info
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Template helper methods for rendering strings with Home Assistant data."""
|
||||
import asyncio
|
||||
import base64
|
||||
import collections.abc
|
||||
from datetime import datetime, timedelta
|
||||
|
@ -36,6 +37,7 @@ from homeassistant.helpers.typing import HomeAssistantType, TemplateVarsType
|
|||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import convert, dt as dt_util, location as loc_util
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
from homeassistant.util.thread import ThreadWithException
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
# mypy: no-check-untyped-defs, no-warn-return-any
|
||||
|
@ -309,6 +311,54 @@ class Template:
|
|||
except jinja2.TemplateError as err:
|
||||
raise TemplateError(err) from err
|
||||
|
||||
async def async_render_will_timeout(
|
||||
self, timeout: float, variables: TemplateVarsType = None, **kwargs: Any
|
||||
) -> bool:
|
||||
"""Check to see if rendering a template will timeout during render.
|
||||
|
||||
This is intended to check for expensive templates
|
||||
that will make the system unstable. The template
|
||||
is rendered in the executor to ensure it does not
|
||||
tie up the event loop.
|
||||
|
||||
This function is not a security control and is only
|
||||
intended to be used as a safety check when testing
|
||||
templates.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
assert self.hass
|
||||
|
||||
if self.is_static:
|
||||
return False
|
||||
|
||||
compiled = self._compiled or self._ensure_compiled()
|
||||
|
||||
if variables is not None:
|
||||
kwargs.update(variables)
|
||||
|
||||
finish_event = asyncio.Event()
|
||||
|
||||
def _render_template():
|
||||
try:
|
||||
compiled.render(kwargs)
|
||||
except TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
run_callback_threadsafe(self.hass.loop, finish_event.set)
|
||||
|
||||
try:
|
||||
template_render_thread = ThreadWithException(target=_render_template)
|
||||
template_render_thread.start()
|
||||
await asyncio.wait_for(finish_event.wait(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
template_render_thread.raise_exc(TimeoutError)
|
||||
return True
|
||||
finally:
|
||||
template_render_thread.join()
|
||||
|
||||
return False
|
||||
|
||||
@callback
|
||||
def async_render_to_info(
|
||||
self, variables: TemplateVarsType = None, **kwargs: Any
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""Threading util helpers."""
|
||||
import ctypes
|
||||
import inspect
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
|
@ -24,3 +26,34 @@ def fix_threading_exception_logging() -> None:
|
|||
sys.excepthook(*sys.exc_info())
|
||||
|
||||
threading.Thread.run = run # type: ignore
|
||||
|
||||
|
||||
def _async_raise(tid: int, exctype: Any) -> None:
|
||||
"""Raise an exception in the threads with id tid."""
|
||||
if not inspect.isclass(exctype):
|
||||
raise TypeError("Only types can be raised (not instances)")
|
||||
|
||||
c_tid = ctypes.c_long(tid)
|
||||
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(c_tid, ctypes.py_object(exctype))
|
||||
|
||||
if res == 1:
|
||||
return
|
||||
|
||||
# "if it returns a number greater than one, you're in trouble,
|
||||
# and you should call it again with exc=NULL to revert the effect"
|
||||
ctypes.pythonapi.PyThreadState_SetAsyncExc(c_tid, None)
|
||||
raise SystemError("PyThreadState_SetAsyncExc failed")
|
||||
|
||||
|
||||
class ThreadWithException(threading.Thread):
|
||||
"""A thread class that supports raising exception in the thread from another thread.
|
||||
|
||||
Based on
|
||||
https://stackoverflow.com/questions/323972/is-there-any-way-to-kill-a-thread/49877671
|
||||
|
||||
"""
|
||||
|
||||
def raise_exc(self, exctype: Any) -> None:
|
||||
"""Raise the given exception type in the context of this thread."""
|
||||
assert self.ident
|
||||
_async_raise(self.ident, exctype)
|
||||
|
|
|
@ -397,9 +397,7 @@ async def test_subscribe_unsubscribe_events_state_changed(
|
|||
assert msg["event"]["data"]["entity_id"] == "light.permitted"
|
||||
|
||||
|
||||
async def test_render_template_renders_template(
|
||||
hass, websocket_client, hass_admin_user
|
||||
):
|
||||
async def test_render_template_renders_template(hass, websocket_client):
|
||||
"""Test simple template is rendered and updated."""
|
||||
hass.states.async_set("light.test", "on")
|
||||
|
||||
|
@ -437,7 +435,7 @@ async def test_render_template_renders_template(
|
|||
|
||||
|
||||
async def test_render_template_manual_entity_ids_no_longer_needed(
|
||||
hass, websocket_client, hass_admin_user
|
||||
hass, websocket_client
|
||||
):
|
||||
"""Test that updates to specified entity ids cause a template rerender."""
|
||||
hass.states.async_set("light.test", "on")
|
||||
|
@ -475,9 +473,7 @@ async def test_render_template_manual_entity_ids_no_longer_needed(
|
|||
}
|
||||
|
||||
|
||||
async def test_render_template_with_error(
|
||||
hass, websocket_client, hass_admin_user, caplog
|
||||
):
|
||||
async def test_render_template_with_error(hass, websocket_client, caplog):
|
||||
"""Test a template with an error."""
|
||||
await websocket_client.send_json(
|
||||
{"id": 5, "type": "render_template", "template": "{{ my_unknown_var() + 1 }}"}
|
||||
|
@ -492,9 +488,7 @@ async def test_render_template_with_error(
|
|||
assert "TemplateError" not in caplog.text
|
||||
|
||||
|
||||
async def test_render_template_with_delayed_error(
|
||||
hass, websocket_client, hass_admin_user, caplog
|
||||
):
|
||||
async def test_render_template_with_delayed_error(hass, websocket_client, caplog):
|
||||
"""Test a template with an error that only happens after a state change."""
|
||||
hass.states.async_set("sensor.test", "on")
|
||||
await hass.async_block_till_done()
|
||||
|
@ -539,9 +533,36 @@ async def test_render_template_with_delayed_error(
|
|||
assert "TemplateError" not in caplog.text
|
||||
|
||||
|
||||
async def test_render_template_returns_with_match_all(
|
||||
hass, websocket_client, hass_admin_user
|
||||
):
|
||||
async def test_render_template_with_timeout(hass, websocket_client, caplog):
|
||||
"""Test a template that will timeout."""
|
||||
|
||||
slow_template_str = """
|
||||
{% for var in range(1000) -%}
|
||||
{% for var in range(1000) -%}
|
||||
{{ var }}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
"""
|
||||
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
"id": 5,
|
||||
"type": "render_template",
|
||||
"timeout": 0.000001,
|
||||
"template": slow_template_str,
|
||||
}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 5
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == const.ERR_TEMPLATE_ERROR
|
||||
|
||||
assert "TemplateError" not in caplog.text
|
||||
|
||||
|
||||
async def test_render_template_returns_with_match_all(hass, websocket_client):
|
||||
"""Test that a template that would match with all entities still return success."""
|
||||
await websocket_client.send_json(
|
||||
{"id": 5, "type": "render_template", "template": "State is: {{ 42 }}"}
|
||||
|
|
|
@ -2455,3 +2455,31 @@ async def test_lifecycle(hass):
|
|||
assert info.filter("sensor.sensor1") is False
|
||||
assert info.filter_lifecycle("sensor.new") is True
|
||||
assert info.filter_lifecycle("sensor.removed") is True
|
||||
|
||||
|
||||
async def test_template_timeout(hass):
|
||||
"""Test to see if a template will timeout."""
|
||||
for i in range(2):
|
||||
hass.states.async_set(f"sensor.sensor{i}", "on")
|
||||
|
||||
tmp = template.Template("{{ states | count }}", hass)
|
||||
assert await tmp.async_render_will_timeout(3) is False
|
||||
|
||||
tmp2 = template.Template("{{ error_invalid + 1 }}", hass)
|
||||
assert await tmp2.async_render_will_timeout(3) is False
|
||||
|
||||
tmp3 = template.Template("static", hass)
|
||||
assert await tmp3.async_render_will_timeout(3) is False
|
||||
|
||||
tmp4 = template.Template("{{ var1 }}", hass)
|
||||
assert await tmp4.async_render_will_timeout(3, {"var1": "ok"}) is False
|
||||
|
||||
slow_template_str = """
|
||||
{% for var in range(1000) -%}
|
||||
{% for var in range(1000) -%}
|
||||
{{ var }}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
"""
|
||||
tmp5 = template.Template(slow_template_str, hass)
|
||||
assert await tmp5.async_render_will_timeout(0.000001) is True
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
"""Test Home Assistant thread utils."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
from homeassistant.util.thread import ThreadWithException
|
||||
|
||||
|
||||
async def test_thread_with_exception_invalid(hass):
|
||||
"""Test throwing an invalid thread exception."""
|
||||
|
||||
finish_event = asyncio.Event()
|
||||
|
||||
def _do_nothing(*_):
|
||||
run_callback_threadsafe(hass.loop, finish_event.set)
|
||||
|
||||
test_thread = ThreadWithException(target=_do_nothing)
|
||||
test_thread.start()
|
||||
await asyncio.wait_for(finish_event.wait(), timeout=0.1)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
test_thread.raise_exc(_EmptyClass())
|
||||
test_thread.join()
|
||||
|
||||
|
||||
async def test_thread_not_started(hass):
|
||||
"""Test throwing when the thread is not started."""
|
||||
|
||||
test_thread = ThreadWithException(target=lambda *_: None)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
test_thread.raise_exc(TimeoutError)
|
||||
|
||||
|
||||
async def test_thread_fails_raise(hass):
|
||||
"""Test throwing after already ended."""
|
||||
|
||||
finish_event = asyncio.Event()
|
||||
|
||||
def _do_nothing(*_):
|
||||
run_callback_threadsafe(hass.loop, finish_event.set)
|
||||
|
||||
test_thread = ThreadWithException(target=_do_nothing)
|
||||
test_thread.start()
|
||||
await asyncio.wait_for(finish_event.wait(), timeout=0.1)
|
||||
test_thread.join()
|
||||
|
||||
with pytest.raises(SystemError):
|
||||
test_thread.raise_exc(ValueError)
|
||||
|
||||
|
||||
class _EmptyClass:
|
||||
"""An empty class."""
|
Loading…
Reference in New Issue