Abort execution of template renders that overwhelm the system (#40647)

pull/40708/head
J. Nick Koston 2020-09-28 07:43:22 -05:00 committed by GitHub
parent 8bc62f3678
commit e08ee282ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 212 additions and 15 deletions

View File

@ -239,22 +239,32 @@ def handle_ping(hass, connection, msg):
connection.send_message(pong_message(msg["id"]))
@callback
@decorators.websocket_command(
{
vol.Required("type"): "render_template",
vol.Required("template"): str,
vol.Optional("entity_ids"): cv.entity_ids,
vol.Optional("variables"): dict,
vol.Optional("timeout"): vol.Coerce(float),
}
)
def handle_render_template(hass, connection, msg):
@decorators.async_response
async def handle_render_template(hass, connection, msg):
"""Handle render_template command."""
template_str = msg["template"]
template = Template(template_str, hass)
variables = msg.get("variables")
timeout = msg.get("timeout")
info = None
if timeout and await template.async_render_will_timeout(timeout):
connection.send_error(
msg["id"],
const.ERR_TEMPLATE_ERROR,
f"Exceeded maximum execution time of {timeout}s",
)
return
@callback
def _template_listener(event, updates):
nonlocal info

View File

@ -1,4 +1,5 @@
"""Template helper methods for rendering strings with Home Assistant data."""
import asyncio
import base64
import collections.abc
from datetime import datetime, timedelta
@ -36,6 +37,7 @@ from homeassistant.helpers.typing import HomeAssistantType, TemplateVarsType
from homeassistant.loader import bind_hass
from homeassistant.util import convert, dt as dt_util, location as loc_util
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.thread import ThreadWithException
# mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any
@ -309,6 +311,54 @@ class Template:
except jinja2.TemplateError as err:
raise TemplateError(err) from err
async def async_render_will_timeout(
self, timeout: float, variables: TemplateVarsType = None, **kwargs: Any
) -> bool:
"""Check to see if rendering a template will timeout during render.
This is intended to check for expensive templates
that will make the system unstable. The template
is rendered in the executor to ensure it does not
tie up the event loop.
This function is not a security control and is only
intended to be used as a safety check when testing
templates.
This method must be run in the event loop.
"""
assert self.hass
if self.is_static:
return False
compiled = self._compiled or self._ensure_compiled()
if variables is not None:
kwargs.update(variables)
finish_event = asyncio.Event()
def _render_template():
try:
compiled.render(kwargs)
except TimeoutError:
pass
finally:
run_callback_threadsafe(self.hass.loop, finish_event.set)
try:
template_render_thread = ThreadWithException(target=_render_template)
template_render_thread.start()
await asyncio.wait_for(finish_event.wait(), timeout=timeout)
except asyncio.TimeoutError:
template_render_thread.raise_exc(TimeoutError)
return True
finally:
template_render_thread.join()
return False
@callback
def async_render_to_info(
self, variables: TemplateVarsType = None, **kwargs: Any

View File

@ -1,4 +1,6 @@
"""Threading util helpers."""
import ctypes
import inspect
import sys
import threading
from typing import Any
@ -24,3 +26,34 @@ def fix_threading_exception_logging() -> None:
sys.excepthook(*sys.exc_info())
threading.Thread.run = run # type: ignore
def _async_raise(tid: int, exctype: Any) -> None:
"""Raise an exception in the threads with id tid."""
if not inspect.isclass(exctype):
raise TypeError("Only types can be raised (not instances)")
c_tid = ctypes.c_long(tid)
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(c_tid, ctypes.py_object(exctype))
if res == 1:
return
# "if it returns a number greater than one, you're in trouble,
# and you should call it again with exc=NULL to revert the effect"
ctypes.pythonapi.PyThreadState_SetAsyncExc(c_tid, None)
raise SystemError("PyThreadState_SetAsyncExc failed")
class ThreadWithException(threading.Thread):
"""A thread class that supports raising exception in the thread from another thread.
Based on
https://stackoverflow.com/questions/323972/is-there-any-way-to-kill-a-thread/49877671
"""
def raise_exc(self, exctype: Any) -> None:
"""Raise the given exception type in the context of this thread."""
assert self.ident
_async_raise(self.ident, exctype)

View File

@ -397,9 +397,7 @@ async def test_subscribe_unsubscribe_events_state_changed(
assert msg["event"]["data"]["entity_id"] == "light.permitted"
async def test_render_template_renders_template(
hass, websocket_client, hass_admin_user
):
async def test_render_template_renders_template(hass, websocket_client):
"""Test simple template is rendered and updated."""
hass.states.async_set("light.test", "on")
@ -437,7 +435,7 @@ async def test_render_template_renders_template(
async def test_render_template_manual_entity_ids_no_longer_needed(
hass, websocket_client, hass_admin_user
hass, websocket_client
):
"""Test that updates to specified entity ids cause a template rerender."""
hass.states.async_set("light.test", "on")
@ -475,9 +473,7 @@ async def test_render_template_manual_entity_ids_no_longer_needed(
}
async def test_render_template_with_error(
hass, websocket_client, hass_admin_user, caplog
):
async def test_render_template_with_error(hass, websocket_client, caplog):
"""Test a template with an error."""
await websocket_client.send_json(
{"id": 5, "type": "render_template", "template": "{{ my_unknown_var() + 1 }}"}
@ -492,9 +488,7 @@ async def test_render_template_with_error(
assert "TemplateError" not in caplog.text
async def test_render_template_with_delayed_error(
hass, websocket_client, hass_admin_user, caplog
):
async def test_render_template_with_delayed_error(hass, websocket_client, caplog):
"""Test a template with an error that only happens after a state change."""
hass.states.async_set("sensor.test", "on")
await hass.async_block_till_done()
@ -539,9 +533,36 @@ async def test_render_template_with_delayed_error(
assert "TemplateError" not in caplog.text
async def test_render_template_returns_with_match_all(
hass, websocket_client, hass_admin_user
):
async def test_render_template_with_timeout(hass, websocket_client, caplog):
"""Test a template that will timeout."""
slow_template_str = """
{% for var in range(1000) -%}
{% for var in range(1000) -%}
{{ var }}
{%- endfor %}
{%- endfor %}
"""
await websocket_client.send_json(
{
"id": 5,
"type": "render_template",
"timeout": 0.000001,
"template": slow_template_str,
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 5
assert msg["type"] == const.TYPE_RESULT
assert not msg["success"]
assert msg["error"]["code"] == const.ERR_TEMPLATE_ERROR
assert "TemplateError" not in caplog.text
async def test_render_template_returns_with_match_all(hass, websocket_client):
"""Test that a template that would match with all entities still return success."""
await websocket_client.send_json(
{"id": 5, "type": "render_template", "template": "State is: {{ 42 }}"}

View File

@ -2455,3 +2455,31 @@ async def test_lifecycle(hass):
assert info.filter("sensor.sensor1") is False
assert info.filter_lifecycle("sensor.new") is True
assert info.filter_lifecycle("sensor.removed") is True
async def test_template_timeout(hass):
"""Test to see if a template will timeout."""
for i in range(2):
hass.states.async_set(f"sensor.sensor{i}", "on")
tmp = template.Template("{{ states | count }}", hass)
assert await tmp.async_render_will_timeout(3) is False
tmp2 = template.Template("{{ error_invalid + 1 }}", hass)
assert await tmp2.async_render_will_timeout(3) is False
tmp3 = template.Template("static", hass)
assert await tmp3.async_render_will_timeout(3) is False
tmp4 = template.Template("{{ var1 }}", hass)
assert await tmp4.async_render_will_timeout(3, {"var1": "ok"}) is False
slow_template_str = """
{% for var in range(1000) -%}
{% for var in range(1000) -%}
{{ var }}
{%- endfor %}
{%- endfor %}
"""
tmp5 = template.Template(slow_template_str, hass)
assert await tmp5.async_render_will_timeout(0.000001) is True

55
tests/util/test_thread.py Normal file
View File

@ -0,0 +1,55 @@
"""Test Home Assistant thread utils."""
import asyncio
import pytest
from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.thread import ThreadWithException
async def test_thread_with_exception_invalid(hass):
"""Test throwing an invalid thread exception."""
finish_event = asyncio.Event()
def _do_nothing(*_):
run_callback_threadsafe(hass.loop, finish_event.set)
test_thread = ThreadWithException(target=_do_nothing)
test_thread.start()
await asyncio.wait_for(finish_event.wait(), timeout=0.1)
with pytest.raises(TypeError):
test_thread.raise_exc(_EmptyClass())
test_thread.join()
async def test_thread_not_started(hass):
"""Test throwing when the thread is not started."""
test_thread = ThreadWithException(target=lambda *_: None)
with pytest.raises(AssertionError):
test_thread.raise_exc(TimeoutError)
async def test_thread_fails_raise(hass):
"""Test throwing after already ended."""
finish_event = asyncio.Event()
def _do_nothing(*_):
run_callback_threadsafe(hass.loop, finish_event.set)
test_thread = ThreadWithException(target=_do_nothing)
test_thread.start()
await asyncio.wait_for(finish_event.wait(), timeout=0.1)
test_thread.join()
with pytest.raises(SystemError):
test_thread.raise_exc(ValueError)
class _EmptyClass:
"""An empty class."""