Add SSL support to TCP integration (#48060)
Co-authored-by: Michael <35783820+mib1185@users.noreply.github.com>pull/50690/head
parent
dab66a58ce
commit
0c37effc72
|
@ -2,6 +2,7 @@
|
|||
import logging
|
||||
import select
|
||||
import socket
|
||||
import ssl
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -11,9 +12,11 @@ from homeassistant.const import (
|
|||
CONF_NAME,
|
||||
CONF_PAYLOAD,
|
||||
CONF_PORT,
|
||||
CONF_SSL,
|
||||
CONF_TIMEOUT,
|
||||
CONF_UNIT_OF_MEASUREMENT,
|
||||
CONF_VALUE_TEMPLATE,
|
||||
CONF_VERIFY_SSL,
|
||||
)
|
||||
from homeassistant.exceptions import TemplateError
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
@ -26,6 +29,8 @@ CONF_VALUE_ON = "value_on"
|
|||
DEFAULT_BUFFER_SIZE = 1024
|
||||
DEFAULT_NAME = "TCP Sensor"
|
||||
DEFAULT_TIMEOUT = 10
|
||||
DEFAULT_SSL = False
|
||||
DEFAULT_VERIFY_SSL = True
|
||||
|
||||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
||||
{
|
||||
|
@ -38,6 +43,8 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
|||
vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
|
||||
vol.Optional(CONF_VALUE_ON): cv.string,
|
||||
vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
|
||||
vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean,
|
||||
vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -71,6 +78,15 @@ class TcpSensor(SensorEntity):
|
|||
CONF_VALUE_ON: config.get(CONF_VALUE_ON),
|
||||
CONF_BUFFER_SIZE: config.get(CONF_BUFFER_SIZE),
|
||||
}
|
||||
|
||||
if config[CONF_SSL]:
|
||||
self._ssl_context = ssl.create_default_context()
|
||||
if not config[CONF_VERIFY_SSL]:
|
||||
self._ssl_context.check_hostname = False
|
||||
self._ssl_context.verify_mode = ssl.CERT_NONE
|
||||
else:
|
||||
self._ssl_context = None
|
||||
|
||||
self._state = None
|
||||
self.update()
|
||||
|
||||
|
@ -104,6 +120,11 @@ class TcpSensor(SensorEntity):
|
|||
)
|
||||
return
|
||||
|
||||
if self._ssl_context is not None:
|
||||
sock = self._ssl_context.wrap_socket(
|
||||
sock, server_hostname=self._config[CONF_HOST]
|
||||
)
|
||||
|
||||
try:
|
||||
sock.send(self._config[CONF_PAYLOAD].encode())
|
||||
except OSError as err:
|
||||
|
|
|
@ -57,6 +57,18 @@ def mock_select_fixture():
|
|||
yield mock_select
|
||||
|
||||
|
||||
@pytest.fixture(name="mock_ssl_context")
|
||||
def mock_ssl_context_fixture():
|
||||
"""Mock select."""
|
||||
with patch(
|
||||
"homeassistant.components.tcp.sensor.ssl.create_default_context",
|
||||
) as mock_ssl_context:
|
||||
mock_ssl_context.return_value.wrap_socket.return_value.recv.return_value = (
|
||||
socket_test_value + "_ssl"
|
||||
).encode()
|
||||
yield mock_ssl_context
|
||||
|
||||
|
||||
async def test_setup_platform_valid_config(hass, mock_socket):
|
||||
"""Check a valid configuration and call add_entities with sensor."""
|
||||
with assert_setup_component(1, "sensor"):
|
||||
|
@ -159,3 +171,66 @@ async def test_update_returns_if_template_render_fails(hass, mock_socket):
|
|||
|
||||
assert state
|
||||
assert state.state == "unknown"
|
||||
|
||||
|
||||
async def test_ssl_state(hass, mock_socket, mock_select, mock_ssl_context):
|
||||
"""Return the contents of _state, updated over SSL."""
|
||||
config = copy(SENSOR_TEST_CONFIG)
|
||||
config[tcp.CONF_SSL] = "on"
|
||||
|
||||
assert await async_setup_component(hass, "sensor", {"sensor": config})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get(TEST_ENTITY)
|
||||
|
||||
assert state
|
||||
assert state.state == "test_value_ssl"
|
||||
assert mock_socket.connect.called
|
||||
assert mock_socket.connect.call_args == call(
|
||||
(SENSOR_TEST_CONFIG["host"], SENSOR_TEST_CONFIG["port"])
|
||||
)
|
||||
assert not mock_socket.send.called
|
||||
assert mock_ssl_context.called
|
||||
assert mock_ssl_context.return_value.check_hostname
|
||||
mock_ssl_socket = mock_ssl_context.return_value.wrap_socket.return_value
|
||||
assert mock_ssl_socket.send.called
|
||||
assert mock_ssl_socket.send.call_args == call(
|
||||
SENSOR_TEST_CONFIG["payload"].encode()
|
||||
)
|
||||
assert mock_select.call_args == call(
|
||||
[mock_ssl_socket], [], [], SENSOR_TEST_CONFIG[tcp.CONF_TIMEOUT]
|
||||
)
|
||||
assert mock_ssl_socket.recv.called
|
||||
assert mock_ssl_socket.recv.call_args == call(SENSOR_TEST_CONFIG["buffer_size"])
|
||||
|
||||
|
||||
async def test_ssl_state_verify_off(hass, mock_socket, mock_select, mock_ssl_context):
|
||||
"""Return the contents of _state, updated over SSL (verify_ssl disabled)."""
|
||||
config = copy(SENSOR_TEST_CONFIG)
|
||||
config[tcp.CONF_SSL] = "on"
|
||||
config[tcp.CONF_VERIFY_SSL] = "off"
|
||||
|
||||
assert await async_setup_component(hass, "sensor", {"sensor": config})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get(TEST_ENTITY)
|
||||
|
||||
assert state
|
||||
assert state.state == "test_value_ssl"
|
||||
assert mock_socket.connect.called
|
||||
assert mock_socket.connect.call_args == call(
|
||||
(SENSOR_TEST_CONFIG["host"], SENSOR_TEST_CONFIG["port"])
|
||||
)
|
||||
assert not mock_socket.send.called
|
||||
assert mock_ssl_context.called
|
||||
assert not mock_ssl_context.return_value.check_hostname
|
||||
mock_ssl_socket = mock_ssl_context.return_value.wrap_socket.return_value
|
||||
assert mock_ssl_socket.send.called
|
||||
assert mock_ssl_socket.send.call_args == call(
|
||||
SENSOR_TEST_CONFIG["payload"].encode()
|
||||
)
|
||||
assert mock_select.call_args == call(
|
||||
[mock_ssl_socket], [], [], SENSOR_TEST_CONFIG[tcp.CONF_TIMEOUT]
|
||||
)
|
||||
assert mock_ssl_socket.recv.called
|
||||
assert mock_ssl_socket.recv.call_args == call(SENSOR_TEST_CONFIG["buffer_size"])
|
||||
|
|
Loading…
Reference in New Issue