Additional SSL validation checks for cert_expiry (#28047)
* Additional SSL validation checks * Add validity attribute, log errors on import * Don't log from sensorpull/28139/head
parent
a644182b5e
commit
44bf9e9ddc
|
@ -1,5 +1,7 @@
|
|||
"""Config flow for the Cert Expiry platform."""
|
||||
import logging
|
||||
import socket
|
||||
import ssl
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
|
@ -9,6 +11,8 @@ from homeassistant.core import HomeAssistant, callback
|
|||
from .const import DOMAIN, DEFAULT_PORT, DEFAULT_NAME
|
||||
from .helper import get_cert
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def certexpiry_entries(hass: HomeAssistant):
|
||||
|
@ -39,17 +43,28 @@ class CertexpiryConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
|
||||
async def _test_connection(self, user_input=None):
|
||||
"""Test connection to the server and try to get the certtificate."""
|
||||
host = user_input[CONF_HOST]
|
||||
try:
|
||||
await self.hass.async_add_executor_job(
|
||||
get_cert, user_input[CONF_HOST], user_input.get(CONF_PORT, DEFAULT_PORT)
|
||||
get_cert, host, user_input.get(CONF_PORT, DEFAULT_PORT)
|
||||
)
|
||||
return True
|
||||
except socket.gaierror:
|
||||
_LOGGER.error("Host cannot be resolved: %s", host)
|
||||
self._errors[CONF_HOST] = "resolve_failed"
|
||||
except socket.timeout:
|
||||
_LOGGER.error("Timed out connecting to %s", host)
|
||||
self._errors[CONF_HOST] = "connection_timeout"
|
||||
except OSError:
|
||||
self._errors[CONF_HOST] = "certificate_fetch_failed"
|
||||
except ssl.CertificateError as err:
|
||||
if "doesn't match" in err.args[0]:
|
||||
_LOGGER.error("Certificate does not match host: %s", host)
|
||||
self._errors[CONF_HOST] = "wrong_host"
|
||||
else:
|
||||
_LOGGER.error("Certificate could not be validated: %s", host)
|
||||
self._errors[CONF_HOST] = "certificate_error"
|
||||
except ssl.SSLError:
|
||||
_LOGGER.error("Certificate could not be validated: %s", host)
|
||||
self._errors[CONF_HOST] = "certificate_error"
|
||||
return False
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
|
|
|
@ -70,6 +70,7 @@ class SSLCertificate(Entity):
|
|||
self._name = sensor_name
|
||||
self._state = None
|
||||
self._available = False
|
||||
self._valid = False
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
|
@ -122,16 +123,17 @@ class SSLCertificate(Entity):
|
|||
except socket.gaierror:
|
||||
_LOGGER.error("Cannot resolve hostname: %s", self.server_name)
|
||||
self._available = False
|
||||
self._valid = False
|
||||
return
|
||||
except socket.timeout:
|
||||
_LOGGER.error("Connection timeout with server: %s", self.server_name)
|
||||
self._available = False
|
||||
self._valid = False
|
||||
return
|
||||
except OSError:
|
||||
_LOGGER.error(
|
||||
"Cannot fetch certificate from %s", self.server_name, exc_info=1
|
||||
)
|
||||
self._available = False
|
||||
except (ssl.CertificateError, ssl.SSLError):
|
||||
self._available = True
|
||||
self._state = 0
|
||||
self._valid = False
|
||||
return
|
||||
|
||||
ts_seconds = ssl.cert_time_to_seconds(cert["notAfter"])
|
||||
|
@ -139,3 +141,11 @@ class SSLCertificate(Entity):
|
|||
expiry = timestamp - datetime.today()
|
||||
self._available = True
|
||||
self._state = expiry.days
|
||||
self._valid = True
|
||||
|
||||
@property
|
||||
def device_state_attributes(self):
|
||||
"""Return additional sensor state attributes."""
|
||||
attr = {"is_valid": self._valid}
|
||||
|
||||
return attr
|
||||
|
|
|
@ -15,7 +15,8 @@
|
|||
"host_port_exists": "This host and port combination is already configured",
|
||||
"resolve_failed": "This host can not be resolved",
|
||||
"connection_timeout": "Timeout when connecting to this host",
|
||||
"certificate_fetch_failed": "Can not fetch certificate from this host and port combination"
|
||||
"certificate_error": "Certificate could not be validated",
|
||||
"wrong_host": "Certificate does not match hostname"
|
||||
},
|
||||
"abort": {
|
||||
"host_port_exists": "This host and port combination is already configured"
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""Tests for the Cert Expiry config flow."""
|
||||
import pytest
|
||||
import ssl
|
||||
import socket
|
||||
from unittest.mock import patch
|
||||
|
||||
|
@ -131,7 +132,22 @@ async def test_abort_on_socket_failed(hass):
|
|||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["errors"] == {CONF_HOST: "connection_timeout"}
|
||||
|
||||
with patch("socket.create_connection", side_effect=OSError()):
|
||||
with patch(
|
||||
"socket.create_connection",
|
||||
side_effect=ssl.CertificateError(f"{HOST} doesn't match somethingelse.com"),
|
||||
):
|
||||
result = await flow.async_step_user({CONF_HOST: HOST})
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["errors"] == {CONF_HOST: "certificate_fetch_failed"}
|
||||
assert result["errors"] == {CONF_HOST: "wrong_host"}
|
||||
|
||||
with patch(
|
||||
"socket.create_connection", side_effect=ssl.CertificateError("different error")
|
||||
):
|
||||
result = await flow.async_step_user({CONF_HOST: HOST})
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["errors"] == {CONF_HOST: "certificate_error"}
|
||||
|
||||
with patch("socket.create_connection", side_effect=ssl.SSLError()):
|
||||
result = await flow.async_step_user({CONF_HOST: HOST})
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["errors"] == {CONF_HOST: "certificate_error"}
|
||||
|
|
Loading…
Reference in New Issue