Improve certificate handling in MQTT config flow (#137234)
* Improve mqtt broker certificate handling in config flow * Expand test casespull/139598/head
parent
dd21d48ae4
commit
913a4ee9ba
|
@ -5,14 +5,21 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable, Mapping
|
||||
from enum import IntEnum
|
||||
import logging
|
||||
import queue
|
||||
from ssl import PROTOCOL_TLS_CLIENT, SSLContext, SSLError
|
||||
from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
||||
from cryptography.x509 import load_pem_x509_certificate
|
||||
from cryptography.hazmat.primitives.serialization import (
|
||||
Encoding,
|
||||
NoEncryption,
|
||||
PrivateFormat,
|
||||
load_der_private_key,
|
||||
load_pem_private_key,
|
||||
)
|
||||
from cryptography.x509 import load_der_x509_certificate, load_pem_x509_certificate
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.file_upload import process_uploaded_file
|
||||
|
@ -105,6 +112,8 @@ _LOGGER = logging.getLogger(__name__)
|
|||
ADDON_SETUP_TIMEOUT = 5
|
||||
ADDON_SETUP_TIMEOUT_ROUNDS = 5
|
||||
|
||||
CONF_CLIENT_KEY_PASSWORD = "client_key_password"
|
||||
|
||||
MQTT_TIMEOUT = 5
|
||||
|
||||
ADVANCED_OPTIONS = "advanced_options"
|
||||
|
@ -165,12 +174,14 @@ BROKER_VERIFICATION_SELECTOR = SelectSelector(
|
|||
|
||||
# mime configuration from https://pki-tutorial.readthedocs.io/en/latest/mime.html
|
||||
CA_CERT_UPLOAD_SELECTOR = FileSelector(
|
||||
FileSelectorConfig(accept=".crt,application/x-x509-ca-cert")
|
||||
FileSelectorConfig(accept=".pem,.crt,.cer,.der,application/x-x509-ca-cert")
|
||||
)
|
||||
CERT_UPLOAD_SELECTOR = FileSelector(
|
||||
FileSelectorConfig(accept=".crt,application/x-x509-user-cert")
|
||||
FileSelectorConfig(accept=".pem,.crt,.cer,.der,application/x-x509-user-cert")
|
||||
)
|
||||
KEY_UPLOAD_SELECTOR = FileSelector(
|
||||
FileSelectorConfig(accept=".pem,.key,.der,.pk8,application/pkcs8")
|
||||
)
|
||||
KEY_UPLOAD_SELECTOR = FileSelector(FileSelectorConfig(accept=".key,application/pkcs8"))
|
||||
|
||||
REAUTH_SCHEMA = vol.Schema(
|
||||
{
|
||||
|
@ -710,17 +721,88 @@ class MQTTOptionsFlowHandler(OptionsFlow):
|
|||
)
|
||||
|
||||
|
||||
async def _get_uploaded_file(hass: HomeAssistant, id: str) -> str:
|
||||
"""Get file content from uploaded file."""
|
||||
@callback
|
||||
def async_is_pem_data(data: bytes) -> bool:
|
||||
"""Return True if data is in PEM format."""
|
||||
return (
|
||||
b"-----BEGIN CERTIFICATE-----" in data
|
||||
or b"-----BEGIN PRIVATE KEY-----" in data
|
||||
or b"-----BEGIN RSA PRIVATE KEY-----" in data
|
||||
or b"-----BEGIN ENCRYPTED PRIVATE KEY-----" in data
|
||||
)
|
||||
|
||||
def _proces_uploaded_file() -> str:
|
||||
|
||||
class PEMType(IntEnum):
|
||||
"""Type of PEM data."""
|
||||
|
||||
CERTIFICATE = 1
|
||||
PRIVATE_KEY = 2
|
||||
|
||||
|
||||
@callback
|
||||
def async_convert_to_pem(
|
||||
data: bytes, pem_type: PEMType, password: str | None = None
|
||||
) -> str | None:
|
||||
"""Convert data to PEM format."""
|
||||
try:
|
||||
if async_is_pem_data(data):
|
||||
if not password:
|
||||
# Assume unencrypted PEM encoded private key
|
||||
return data.decode(DEFAULT_ENCODING)
|
||||
# Return decrypted PEM encoded private key
|
||||
return (
|
||||
load_pem_private_key(data, password=password.encode(DEFAULT_ENCODING))
|
||||
.private_bytes(
|
||||
encoding=Encoding.PEM,
|
||||
format=PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=NoEncryption(),
|
||||
)
|
||||
.decode(DEFAULT_ENCODING)
|
||||
)
|
||||
# Convert from DER encoding to PEM
|
||||
if pem_type == PEMType.CERTIFICATE:
|
||||
return (
|
||||
load_der_x509_certificate(data)
|
||||
.public_bytes(
|
||||
encoding=Encoding.PEM,
|
||||
)
|
||||
.decode(DEFAULT_ENCODING)
|
||||
)
|
||||
# Assume DER encoded private key
|
||||
pem_key_data: bytes = load_der_private_key(
|
||||
data, password.encode(DEFAULT_ENCODING) if password else None
|
||||
).private_bytes(
|
||||
encoding=Encoding.PEM,
|
||||
format=PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=NoEncryption(),
|
||||
)
|
||||
return pem_key_data.decode("utf-8")
|
||||
except (TypeError, ValueError, SSLError):
|
||||
_LOGGER.exception("Error converting %s file data to PEM format", pem_type.name)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_uploaded_file(hass: HomeAssistant, id: str) -> bytes:
|
||||
"""Get file content from uploaded certificate or key file."""
|
||||
|
||||
def _proces_uploaded_file() -> bytes:
|
||||
with process_uploaded_file(hass, id) as file_path:
|
||||
return file_path.read_text(encoding=DEFAULT_ENCODING)
|
||||
return file_path.read_bytes()
|
||||
|
||||
return await hass.async_add_executor_job(_proces_uploaded_file)
|
||||
|
||||
|
||||
async def async_get_broker_settings(
|
||||
def _validate_pki_file(
|
||||
file_id: str | None, pem_data: str | None, errors: dict[str, str], error: str
|
||||
) -> bool:
|
||||
"""Return False if uploaded file could not be converted to PEM format."""
|
||||
if file_id and not pem_data:
|
||||
errors["base"] = error
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def async_get_broker_settings( # noqa: C901
|
||||
flow: ConfigFlow | OptionsFlow,
|
||||
fields: OrderedDict[Any, Any],
|
||||
entry_config: MappingProxyType[str, Any] | None,
|
||||
|
@ -768,6 +850,10 @@ async def async_get_broker_settings(
|
|||
validated_user_input.update(user_input)
|
||||
client_certificate_id: str | None = user_input.get(CONF_CLIENT_CERT)
|
||||
client_key_id: str | None = user_input.get(CONF_CLIENT_KEY)
|
||||
# We do not store the private key password in the entry data
|
||||
client_key_password: str | None = validated_user_input.pop(
|
||||
CONF_CLIENT_KEY_PASSWORD, None
|
||||
)
|
||||
if (client_certificate_id and not client_key_id) or (
|
||||
not client_certificate_id and client_key_id
|
||||
):
|
||||
|
@ -775,7 +861,14 @@ async def async_get_broker_settings(
|
|||
return False
|
||||
certificate_id: str | None = user_input.get(CONF_CERTIFICATE)
|
||||
if certificate_id:
|
||||
certificate = await _get_uploaded_file(hass, certificate_id)
|
||||
certificate_data_raw = await _get_uploaded_file(hass, certificate_id)
|
||||
certificate = async_convert_to_pem(
|
||||
certificate_data_raw, PEMType.CERTIFICATE
|
||||
)
|
||||
if not _validate_pki_file(
|
||||
certificate_id, certificate, errors, "bad_certificate"
|
||||
):
|
||||
return False
|
||||
|
||||
# Return to form for file upload CA cert or client cert and key
|
||||
if (
|
||||
|
@ -797,9 +890,26 @@ async def async_get_broker_settings(
|
|||
return False
|
||||
|
||||
if client_certificate_id:
|
||||
client_certificate = await _get_uploaded_file(hass, client_certificate_id)
|
||||
client_certificate_data = await _get_uploaded_file(
|
||||
hass, client_certificate_id
|
||||
)
|
||||
client_certificate = async_convert_to_pem(
|
||||
client_certificate_data, PEMType.CERTIFICATE
|
||||
)
|
||||
if not _validate_pki_file(
|
||||
client_certificate_id, client_certificate, errors, "bad_client_cert"
|
||||
):
|
||||
return False
|
||||
|
||||
if client_key_id:
|
||||
client_key = await _get_uploaded_file(hass, client_key_id)
|
||||
client_key_data = await _get_uploaded_file(hass, client_key_id)
|
||||
client_key = async_convert_to_pem(
|
||||
client_key_data, PEMType.PRIVATE_KEY, password=client_key_password
|
||||
)
|
||||
if not _validate_pki_file(
|
||||
client_key_id, client_key, errors, "client_key_error"
|
||||
):
|
||||
return False
|
||||
|
||||
certificate_data: dict[str, Any] = {}
|
||||
if certificate:
|
||||
|
@ -956,6 +1066,14 @@ async def async_get_broker_settings(
|
|||
description={"suggested_value": user_input_basic.get(CONF_CLIENT_KEY)},
|
||||
)
|
||||
] = KEY_UPLOAD_SELECTOR
|
||||
fields[
|
||||
vol.Optional(
|
||||
CONF_CLIENT_KEY_PASSWORD,
|
||||
description={
|
||||
"suggested_value": user_input_basic.get(CONF_CLIENT_KEY_PASSWORD)
|
||||
},
|
||||
)
|
||||
] = PASSWORD_SELECTOR
|
||||
verification_mode = current_config.get(SET_CA_CERT) or (
|
||||
"off"
|
||||
if current_ca_certificate is None
|
||||
|
@ -1060,7 +1178,7 @@ def check_certicate_chain() -> str | None:
|
|||
with open(private_key, "rb") as client_key_file:
|
||||
load_pem_private_key(client_key_file.read(), password=None)
|
||||
except (TypeError, ValueError):
|
||||
return "bad_client_key"
|
||||
return "client_key_error"
|
||||
# Check the certificate chain
|
||||
context = SSLContext(PROTOCOL_TLS_CLIENT)
|
||||
if client_certificate and private_key:
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
"client_id": "Client ID (leave empty to randomly generated one)",
|
||||
"client_cert": "Upload client certificate file",
|
||||
"client_key": "Upload private key file",
|
||||
"client_key_password": "[%key:common::config_flow::data::password%]",
|
||||
"keepalive": "The time between sending keep alive messages",
|
||||
"tls_insecure": "Ignore broker certificate validation",
|
||||
"protocol": "MQTT protocol",
|
||||
|
@ -45,6 +46,7 @@
|
|||
"client_id": "The unique ID to identify the Home Assistant MQTT API as MQTT client. It is recommended to leave this option blank.",
|
||||
"client_cert": "The client certificate to authenticate against your MQTT broker.",
|
||||
"client_key": "The private key file that belongs to your client certificate.",
|
||||
"client_key_password": "The password for the private key file (if set).",
|
||||
"keepalive": "A value less than 90 seconds is advised.",
|
||||
"tls_insecure": "Option to ignore validation of your MQTT broker's certificate.",
|
||||
"protocol": "The MQTT protocol your broker operates at. For example 3.1.1.",
|
||||
|
@ -93,8 +95,8 @@
|
|||
"bad_will": "Invalid will topic",
|
||||
"bad_discovery_prefix": "Invalid discovery prefix",
|
||||
"bad_certificate": "The CA certificate is invalid",
|
||||
"bad_client_cert": "Invalid client certificate, ensure a PEM coded file is supplied",
|
||||
"bad_client_key": "Invalid private key, ensure a PEM coded file is supplied without password",
|
||||
"bad_client_cert": "Invalid client certificate, ensure a valid file is supplied",
|
||||
"client_key_error": "Invalid private key file or invalid password supplied",
|
||||
"bad_client_cert_key": "Client certificate and private key are not a valid pair",
|
||||
"bad_ws_headers": "Supply valid HTTP headers as a JSON object",
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||
|
@ -207,7 +209,7 @@
|
|||
"bad_discovery_prefix": "[%key:component::mqtt::config::error::bad_discovery_prefix%]",
|
||||
"bad_certificate": "[%key:component::mqtt::config::error::bad_certificate%]",
|
||||
"bad_client_cert": "[%key:component::mqtt::config::error::bad_client_cert%]",
|
||||
"bad_client_key": "[%key:component::mqtt::config::error::bad_client_key%]",
|
||||
"client_key_error": "[%key:component::mqtt::config::error::client_key_error%]",
|
||||
"bad_client_cert_key": "[%key:component::mqtt::config::error::bad_client_cert_key%]",
|
||||
"bad_ws_headers": "[%key:component::mqtt::config::error::bad_ws_headers%]",
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||
|
|
|
@ -40,8 +40,37 @@ ADD_ON_DISCOVERY_INFO = {
|
|||
"protocol": "3.1.1",
|
||||
"ssl": False,
|
||||
}
|
||||
MOCK_CLIENT_CERT = b"## mock client certificate file ##"
|
||||
MOCK_CLIENT_KEY = b"## mock key file ##"
|
||||
|
||||
MOCK_CA_CERT = (
|
||||
b"-----BEGIN CERTIFICATE-----\n"
|
||||
b"## mock CA certificate file ##"
|
||||
b"\n-----END CERTIFICATE-----\n"
|
||||
)
|
||||
MOCK_GENERIC_CERT = (
|
||||
b"-----BEGIN CERTIFICATE-----\n"
|
||||
b"## mock generic certificate file ##"
|
||||
b"\n-----END CERTIFICATE-----\n"
|
||||
)
|
||||
MOCK_CA_CERT_DER = b"## mock DER formatted CA certificate file ##\n"
|
||||
MOCK_CLIENT_CERT = (
|
||||
b"-----BEGIN CERTIFICATE-----\n"
|
||||
b"## mock client certificate file ##"
|
||||
b"\n-----END CERTIFICATE-----\n"
|
||||
)
|
||||
MOCK_CLIENT_CERT_DER = b"## mock DER formatted client certificate file ##\n"
|
||||
MOCK_CLIENT_KEY = (
|
||||
b"-----BEGIN PRIVATE KEY-----\n"
|
||||
b"## mock client key file ##"
|
||||
b"\n-----END PRIVATE KEY-----"
|
||||
)
|
||||
MOCK_ENCRYPTED_CLIENT_KEY = (
|
||||
b"-----BEGIN ENCRYPTED PRIVATE KEY-----\n"
|
||||
b"## mock client key file ##\n"
|
||||
b"-----END ENCRYPTED PRIVATE KEY-----"
|
||||
)
|
||||
MOCK_CLIENT_KEY_DER = b"## mock DER formatted key file ##\n"
|
||||
MOCK_ENCRYPTED_CLIENT_KEY_DER = b"## mock DER formatted encrypted key file ##\n"
|
||||
|
||||
|
||||
MOCK_ENTRY_DATA = {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
|
@ -102,15 +131,27 @@ def mock_ssl_context() -> Generator[dict[str, MagicMock]]:
|
|||
patch("homeassistant.components.mqtt.config_flow.SSLContext") as mock_context,
|
||||
patch(
|
||||
"homeassistant.components.mqtt.config_flow.load_pem_private_key"
|
||||
) as mock_key_check,
|
||||
) as mock_pem_key_check,
|
||||
patch(
|
||||
"homeassistant.components.mqtt.config_flow.load_der_private_key"
|
||||
) as mock_der_key_check,
|
||||
patch(
|
||||
"homeassistant.components.mqtt.config_flow.load_pem_x509_certificate"
|
||||
) as mock_cert_check,
|
||||
) as mock_pem_cert_check,
|
||||
patch(
|
||||
"homeassistant.components.mqtt.config_flow.load_der_x509_certificate"
|
||||
) as mock_der_cert_check,
|
||||
):
|
||||
mock_pem_key_check().private_bytes.return_value = MOCK_CLIENT_KEY
|
||||
mock_pem_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT
|
||||
mock_der_key_check().private_bytes.return_value = MOCK_CLIENT_KEY
|
||||
mock_der_cert_check().public_bytes.return_value = MOCK_GENERIC_CERT
|
||||
yield {
|
||||
"context": mock_context,
|
||||
"load_pem_x509_certificate": mock_cert_check,
|
||||
"load_pem_private_key": mock_key_check,
|
||||
"load_der_private_key": mock_der_key_check,
|
||||
"load_der_x509_certificate": mock_der_cert_check,
|
||||
"load_pem_private_key": mock_pem_key_check,
|
||||
"load_pem_x509_certificate": mock_pem_cert_check,
|
||||
}
|
||||
|
||||
|
||||
|
@ -180,9 +221,31 @@ def mock_try_connection_time_out() -> Generator[MagicMock]:
|
|||
yield mock_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ca_cert() -> bytes:
|
||||
"""Mock the CA certificate."""
|
||||
return MOCK_CA_CERT
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_cert() -> bytes:
|
||||
"""Mock the client certificate."""
|
||||
return MOCK_CLIENT_CERT
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client_key() -> bytes:
|
||||
"""Mock the client key."""
|
||||
return MOCK_CLIENT_KEY
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_process_uploaded_file(
|
||||
tmp_path: Path, mock_temp_dir: str
|
||||
tmp_path: Path,
|
||||
mock_ca_cert: bytes,
|
||||
mock_client_cert: bytes,
|
||||
mock_client_key: bytes,
|
||||
mock_temp_dir: str,
|
||||
) -> Generator[MagicMock]:
|
||||
"""Mock upload certificate files."""
|
||||
file_id_ca = str(uuid4())
|
||||
|
@ -195,15 +258,15 @@ def mock_process_uploaded_file(
|
|||
) -> Iterator[Path | None]:
|
||||
if file_id == file_id_ca:
|
||||
with open(tmp_path / "ca.crt", "wb") as cafile:
|
||||
cafile.write(b"## mock CA certificate file ##")
|
||||
cafile.write(mock_ca_cert)
|
||||
yield tmp_path / "ca.crt"
|
||||
elif file_id == file_id_cert:
|
||||
with open(tmp_path / "client.crt", "wb") as certfile:
|
||||
certfile.write(b"## mock client certificate file ##")
|
||||
certfile.write(mock_client_cert)
|
||||
yield tmp_path / "client.crt"
|
||||
elif file_id == file_id_key:
|
||||
with open(tmp_path / "client.key", "wb") as keyfile:
|
||||
keyfile.write(b"## mock key file ##")
|
||||
keyfile.write(mock_client_key)
|
||||
yield tmp_path / "client.key"
|
||||
else:
|
||||
pytest.fail(f"Unexpected file_id: {file_id}")
|
||||
|
@ -1024,12 +1087,37 @@ async def test_option_flow(
|
|||
assert yaml_mock.await_count
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mock_ca_cert", "mock_client_cert", "mock_client_key", "client_key_password"),
|
||||
[
|
||||
(MOCK_GENERIC_CERT, MOCK_GENERIC_CERT, MOCK_CLIENT_KEY, ""),
|
||||
(
|
||||
MOCK_GENERIC_CERT,
|
||||
MOCK_GENERIC_CERT,
|
||||
MOCK_ENCRYPTED_CLIENT_KEY,
|
||||
"very*secret",
|
||||
),
|
||||
(MOCK_CA_CERT_DER, MOCK_CLIENT_CERT_DER, MOCK_CLIENT_KEY_DER, ""),
|
||||
(
|
||||
MOCK_CA_CERT_DER,
|
||||
MOCK_CLIENT_CERT_DER,
|
||||
MOCK_ENCRYPTED_CLIENT_KEY_DER,
|
||||
"very*secret",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"pem_certs_private_key_no_password",
|
||||
"pem_certs_private_key_with_password",
|
||||
"der_certs_private_key_no_password",
|
||||
"der_certs_private_key_with_password",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"test_error",
|
||||
[
|
||||
"bad_certificate",
|
||||
"bad_client_cert",
|
||||
"bad_client_key",
|
||||
"client_key_error",
|
||||
"bad_client_cert_key",
|
||||
"invalid_inclusion",
|
||||
None,
|
||||
|
@ -1042,31 +1130,54 @@ async def test_bad_certificate(
|
|||
mock_ssl_context: dict[str, MagicMock],
|
||||
mock_process_uploaded_file: MagicMock,
|
||||
test_error: str | None,
|
||||
client_key_password: str,
|
||||
mock_ca_cert: bytes,
|
||||
) -> None:
|
||||
"""Test bad certificate tests."""
|
||||
|
||||
def _side_effect_on_client_cert(data: bytes) -> MagicMock:
|
||||
"""Raise on client cert only.
|
||||
|
||||
The function is called twice, once for the CA chain
|
||||
and once for the client cert. We only want to raise on a client cert.
|
||||
"""
|
||||
if data == MOCK_CLIENT_CERT_DER:
|
||||
raise ValueError
|
||||
mock_certificate_side_effect = MagicMock()
|
||||
mock_certificate_side_effect().public_bytes.return_value = MOCK_GENERIC_CERT
|
||||
return mock_certificate_side_effect
|
||||
|
||||
# Mock certificate files
|
||||
file_id = mock_process_uploaded_file.file_id
|
||||
set_ca_cert = "custom"
|
||||
set_client_cert = True
|
||||
tls_insecure = False
|
||||
test_input = {
|
||||
mqtt.CONF_BROKER: "another-broker",
|
||||
CONF_PORT: 2345,
|
||||
mqtt.CONF_CERTIFICATE: file_id[mqtt.CONF_CERTIFICATE],
|
||||
mqtt.CONF_CLIENT_CERT: file_id[mqtt.CONF_CLIENT_CERT],
|
||||
mqtt.CONF_CLIENT_KEY: file_id[mqtt.CONF_CLIENT_KEY],
|
||||
"set_ca_cert": True,
|
||||
"client_key_password": client_key_password,
|
||||
"set_ca_cert": set_ca_cert,
|
||||
"set_client_cert": True,
|
||||
}
|
||||
set_client_cert = True
|
||||
set_ca_cert = "custom"
|
||||
tls_insecure = False
|
||||
if test_error == "bad_certificate":
|
||||
# CA chain is not loading
|
||||
mock_ssl_context["context"]().load_verify_locations.side_effect = SSLError
|
||||
# Fail on the CA cert if DER encoded
|
||||
mock_ssl_context["load_der_x509_certificate"].side_effect = ValueError
|
||||
elif test_error == "bad_client_cert":
|
||||
# Client certificate is invalid
|
||||
mock_ssl_context["load_pem_x509_certificate"].side_effect = ValueError
|
||||
elif test_error == "bad_client_key":
|
||||
# Fail on the client cert if DER encoded
|
||||
mock_ssl_context[
|
||||
"load_der_x509_certificate"
|
||||
].side_effect = _side_effect_on_client_cert
|
||||
elif test_error == "client_key_error":
|
||||
# Client key file is invalid
|
||||
mock_ssl_context["load_pem_private_key"].side_effect = ValueError
|
||||
mock_ssl_context["load_der_private_key"].side_effect = ValueError
|
||||
elif test_error == "bad_client_cert_key":
|
||||
# Client key file file and certificate do not pair
|
||||
mock_ssl_context["context"]().load_cert_chain.side_effect = SSLError
|
||||
|
@ -2078,8 +2189,8 @@ async def test_setup_with_advanced_settings(
|
|||
CONF_USERNAME: "user",
|
||||
CONF_PASSWORD: "secret",
|
||||
mqtt.CONF_KEEPALIVE: 30,
|
||||
mqtt.CONF_CLIENT_CERT: "## mock client certificate file ##",
|
||||
mqtt.CONF_CLIENT_KEY: "## mock key file ##",
|
||||
mqtt.CONF_CLIENT_CERT: MOCK_CLIENT_CERT.decode(encoding="utf-8"),
|
||||
mqtt.CONF_CLIENT_KEY: MOCK_CLIENT_KEY.decode(encoding="utf-8"),
|
||||
"tls_insecure": True,
|
||||
mqtt.CONF_TRANSPORT: "websockets",
|
||||
mqtt.CONF_WS_PATH: "/custom_path/",
|
||||
|
@ -2091,6 +2202,155 @@ async def test_setup_with_advanced_settings(
|
|||
}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_ssl_context")
|
||||
@pytest.mark.parametrize(
|
||||
("mock_ca_cert", "mock_client_cert", "mock_client_key", "client_key_password"),
|
||||
[
|
||||
(MOCK_GENERIC_CERT, MOCK_GENERIC_CERT, MOCK_CLIENT_KEY, ""),
|
||||
(
|
||||
MOCK_GENERIC_CERT,
|
||||
MOCK_GENERIC_CERT,
|
||||
MOCK_ENCRYPTED_CLIENT_KEY,
|
||||
"very*secret",
|
||||
),
|
||||
(MOCK_CA_CERT_DER, MOCK_CLIENT_CERT_DER, MOCK_CLIENT_KEY_DER, ""),
|
||||
(
|
||||
MOCK_CA_CERT_DER,
|
||||
MOCK_CLIENT_CERT_DER,
|
||||
MOCK_ENCRYPTED_CLIENT_KEY_DER,
|
||||
"very*secret",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"pem_certs_private_key_no_password",
|
||||
"pem_certs_private_key_with_password",
|
||||
"der_certs_private_key_no_password",
|
||||
"der_certs_private_key_with_password",
|
||||
],
|
||||
)
|
||||
async def test_setup_with_certificates(
|
||||
hass: HomeAssistant,
|
||||
mock_try_connection: MagicMock,
|
||||
mock_process_uploaded_file: MagicMock,
|
||||
client_key_password: str,
|
||||
) -> None:
|
||||
"""Test config flow setup with PEM and DER encoded certificates."""
|
||||
file_id = mock_process_uploaded_file.file_id
|
||||
|
||||
config_entry = MockConfigEntry(
|
||||
domain=mqtt.DOMAIN,
|
||||
version=mqtt.CONFIG_ENTRY_VERSION,
|
||||
minor_version=mqtt.CONFIG_ENTRY_MINOR_VERSION,
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
hass.config_entries.async_update_entry(
|
||||
config_entry,
|
||||
data={
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
CONF_PORT: 1234,
|
||||
},
|
||||
)
|
||||
|
||||
mock_try_connection.return_value = True
|
||||
|
||||
result = await config_entry.start_reconfigure_flow(hass, show_advanced_options=True)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "broker"
|
||||
assert result["data_schema"].schema["advanced_options"]
|
||||
|
||||
# first iteration, basic settings
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
CONF_PORT: 2345,
|
||||
CONF_USERNAME: "user",
|
||||
CONF_PASSWORD: "secret",
|
||||
"advanced_options": True,
|
||||
},
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "broker"
|
||||
assert "advanced_options" not in result["data_schema"].schema
|
||||
assert result["data_schema"].schema[CONF_CLIENT_ID]
|
||||
assert result["data_schema"].schema[mqtt.CONF_KEEPALIVE]
|
||||
assert result["data_schema"].schema["set_client_cert"]
|
||||
assert result["data_schema"].schema["set_ca_cert"]
|
||||
assert result["data_schema"].schema[mqtt.CONF_TLS_INSECURE]
|
||||
assert result["data_schema"].schema[CONF_PROTOCOL]
|
||||
assert result["data_schema"].schema[mqtt.CONF_TRANSPORT]
|
||||
assert mqtt.CONF_CLIENT_CERT not in result["data_schema"].schema
|
||||
assert mqtt.CONF_CLIENT_KEY not in result["data_schema"].schema
|
||||
|
||||
# second iteration, advanced settings with request for client cert
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
CONF_PORT: 2345,
|
||||
CONF_USERNAME: "user",
|
||||
CONF_PASSWORD: "secret",
|
||||
mqtt.CONF_KEEPALIVE: 30,
|
||||
"set_ca_cert": "custom",
|
||||
"set_client_cert": True,
|
||||
mqtt.CONF_TLS_INSECURE: False,
|
||||
CONF_PROTOCOL: "3.1.1",
|
||||
mqtt.CONF_TRANSPORT: "tcp",
|
||||
},
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "broker"
|
||||
assert "advanced_options" not in result["data_schema"].schema
|
||||
assert result["data_schema"].schema[CONF_CLIENT_ID]
|
||||
assert result["data_schema"].schema[mqtt.CONF_KEEPALIVE]
|
||||
assert result["data_schema"].schema["set_client_cert"]
|
||||
assert result["data_schema"].schema["set_ca_cert"]
|
||||
assert result["data_schema"].schema["client_key_password"]
|
||||
assert result["data_schema"].schema[mqtt.CONF_TLS_INSECURE]
|
||||
assert result["data_schema"].schema[CONF_PROTOCOL]
|
||||
assert result["data_schema"].schema[mqtt.CONF_CERTIFICATE]
|
||||
assert result["data_schema"].schema[mqtt.CONF_CLIENT_CERT]
|
||||
assert result["data_schema"].schema[mqtt.CONF_CLIENT_KEY]
|
||||
assert result["data_schema"].schema[mqtt.CONF_TRANSPORT]
|
||||
|
||||
# third iteration, advanced settings with client cert and key and CA certificate
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
CONF_PORT: 2345,
|
||||
CONF_USERNAME: "user",
|
||||
CONF_PASSWORD: "secret",
|
||||
mqtt.CONF_KEEPALIVE: 30,
|
||||
"set_ca_cert": "custom",
|
||||
"set_client_cert": True,
|
||||
"client_key_password": client_key_password,
|
||||
mqtt.CONF_CERTIFICATE: file_id[mqtt.CONF_CERTIFICATE],
|
||||
mqtt.CONF_CLIENT_CERT: file_id[mqtt.CONF_CLIENT_CERT],
|
||||
mqtt.CONF_CLIENT_KEY: file_id[mqtt.CONF_CLIENT_KEY],
|
||||
mqtt.CONF_TLS_INSECURE: False,
|
||||
mqtt.CONF_TRANSPORT: "tcp",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "reconfigure_successful"
|
||||
|
||||
# Check config entry result
|
||||
assert config_entry.data == {
|
||||
mqtt.CONF_BROKER: "test-broker",
|
||||
CONF_PORT: 2345,
|
||||
CONF_USERNAME: "user",
|
||||
CONF_PASSWORD: "secret",
|
||||
mqtt.CONF_KEEPALIVE: 30,
|
||||
mqtt.CONF_CLIENT_CERT: MOCK_GENERIC_CERT.decode(encoding="utf-8"),
|
||||
mqtt.CONF_CLIENT_KEY: MOCK_CLIENT_KEY.decode(encoding="utf-8"),
|
||||
"tls_insecure": False,
|
||||
mqtt.CONF_TRANSPORT: "tcp",
|
||||
mqtt.CONF_CERTIFICATE: MOCK_GENERIC_CERT.decode(encoding="utf-8"),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_ssl_context", "mock_process_uploaded_file")
|
||||
async def test_change_websockets_transport_to_tcp(
|
||||
hass: HomeAssistant, mock_try_connection: MagicMock
|
||||
|
|
Loading…
Reference in New Issue