Improve MQTT type hints part 2 (#80529)

* Improve typing camera

* Improve typing cover

* b64 encoding can be either bytes or a string.
pull/76999/head
Jan Bouwhuis 2022-11-02 20:33:46 +01:00 committed by GitHub
parent b4ad03784f
commit bda7e416c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 47 deletions

View File

@ -27,6 +27,7 @@ from .mixins import (
async_setup_platform_helper,
warn_for_legacy_schema,
)
from .models import ReceiveMessage
from .util import valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -114,8 +115,8 @@ async def _async_setup_entity(
hass: HomeAssistant,
async_add_entities: AddEntitiesCallback,
config: ConfigType,
config_entry: ConfigEntry | None = None,
discovery_data: dict | None = None,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None = None,
) -> None:
"""Set up the MQTT Camera."""
async_add_entities([MqttCamera(hass, config, config_entry, discovery_data)])
@ -124,31 +125,38 @@ async def _async_setup_entity(
class MqttCamera(MqttEntity, Camera):
"""representation of a MQTT camera."""
_entity_id_format = camera.ENTITY_ID_FORMAT
_attributes_extra_blocked = MQTT_CAMERA_ATTRIBUTES_BLOCKED
_entity_id_format: str = camera.ENTITY_ID_FORMAT
_attributes_extra_blocked: frozenset[str] = MQTT_CAMERA_ATTRIBUTES_BLOCKED
def __init__(self, hass, config, config_entry, discovery_data):
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the MQTT Camera."""
self._last_image = None
self._last_image: bytes | None = None
Camera.__init__(self)
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod
def config_schema():
def config_schema() -> vol.Schema:
"""Return the config schema."""
return DISCOVERY_SCHEMA
def _prepare_subscribe_topics(self):
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
def message_received(msg):
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
self._sub_state = subscription.async_prepare_subscribe_topics(
@ -164,7 +172,7 @@ class MqttCamera(MqttEntity, Camera):
},
)
async def _subscribe_topics(self):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)

View File

@ -31,6 +31,7 @@ from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_loads
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import subscription
@ -50,7 +51,7 @@ from .mixins import (
async_setup_platform_helper,
warn_for_legacy_schema,
)
from .models import MqttCommandTemplate, MqttValueTemplate
from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -113,44 +114,44 @@ MQTT_COVER_ATTRIBUTES_BLOCKED = frozenset(
)
def validate_options(value):
def validate_options(config: ConfigType) -> ConfigType:
"""Validate options.
If set position topic is set then get position topic is set as well.
"""
if CONF_SET_POSITION_TOPIC in value and CONF_GET_POSITION_TOPIC not in value:
if CONF_SET_POSITION_TOPIC in config and CONF_GET_POSITION_TOPIC not in config:
raise vol.Invalid(
f"'{CONF_SET_POSITION_TOPIC}' must be set together with '{CONF_GET_POSITION_TOPIC}'."
)
# if templates are set make sure the topic for the template is also set
if CONF_VALUE_TEMPLATE in value and CONF_STATE_TOPIC not in value:
if CONF_VALUE_TEMPLATE in config and CONF_STATE_TOPIC not in config:
raise vol.Invalid(
f"'{CONF_VALUE_TEMPLATE}' must be set together with '{CONF_STATE_TOPIC}'."
)
if CONF_GET_POSITION_TEMPLATE in value and CONF_GET_POSITION_TOPIC not in value:
if CONF_GET_POSITION_TEMPLATE in config and CONF_GET_POSITION_TOPIC not in config:
raise vol.Invalid(
f"'{CONF_GET_POSITION_TEMPLATE}' must be set together with '{CONF_GET_POSITION_TOPIC}'."
)
if CONF_SET_POSITION_TEMPLATE in value and CONF_SET_POSITION_TOPIC not in value:
if CONF_SET_POSITION_TEMPLATE in config and CONF_SET_POSITION_TOPIC not in config:
raise vol.Invalid(
f"'{CONF_SET_POSITION_TEMPLATE}' must be set together with '{CONF_SET_POSITION_TOPIC}'."
)
if CONF_TILT_COMMAND_TEMPLATE in value and CONF_TILT_COMMAND_TOPIC not in value:
if CONF_TILT_COMMAND_TEMPLATE in config and CONF_TILT_COMMAND_TOPIC not in config:
raise vol.Invalid(
f"'{CONF_TILT_COMMAND_TEMPLATE}' must be set together with '{CONF_TILT_COMMAND_TOPIC}'."
)
if CONF_TILT_STATUS_TEMPLATE in value and CONF_TILT_STATUS_TOPIC not in value:
if CONF_TILT_STATUS_TEMPLATE in config and CONF_TILT_STATUS_TOPIC not in config:
raise vol.Invalid(
f"'{CONF_TILT_STATUS_TEMPLATE}' must be set together with '{CONF_TILT_STATUS_TOPIC}'."
)
return value
return config
_PLATFORM_SCHEMA_BASE = MQTT_BASE_SCHEMA.extend(
@ -251,8 +252,8 @@ async def _async_setup_entity(
hass: HomeAssistant,
async_add_entities: AddEntitiesCallback,
config: ConfigType,
config_entry: ConfigEntry | None = None,
discovery_data: dict | None = None,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None = None,
) -> None:
"""Set up the MQTT Cover."""
async_add_entities([MqttCover(hass, config, config_entry, discovery_data)])
@ -261,26 +262,32 @@ async def _async_setup_entity(
class MqttCover(MqttEntity, CoverEntity):
"""Representation of a cover that can be controlled using MQTT."""
_entity_id_format = cover.ENTITY_ID_FORMAT
_attributes_extra_blocked = MQTT_COVER_ATTRIBUTES_BLOCKED
_entity_id_format: str = cover.ENTITY_ID_FORMAT
_attributes_extra_blocked: frozenset[str] = MQTT_COVER_ATTRIBUTES_BLOCKED
def __init__(self, hass, config, config_entry, discovery_data):
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the cover."""
self._position = None
self._state = None
self._position: int | None = None
self._state: str | None = None
self._optimistic = None
self._tilt_value = None
self._tilt_optimistic = None
self._optimistic: bool | None = None
self._tilt_value: int | None = None
self._tilt_optimistic: bool | None = None
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod
def config_schema():
def config_schema() -> vol.Schema:
"""Return the config schema."""
return DISCOVERY_SCHEMA
def _setup_from_config(self, config):
def _setup_from_config(self, config: ConfigType) -> None:
no_position = (
config.get(CONF_SET_POSITION_TOPIC) is None
and config.get(CONF_GET_POSITION_TOPIC) is None
@ -353,13 +360,13 @@ class MqttCover(MqttEntity, CoverEntity):
config_attributes=template_config_attributes,
).async_render_with_possible_json_value
def _prepare_subscribe_topics(self):
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics = {}
@callback
@log_messages(self.hass, self.entity_id)
def tilt_message_received(msg):
def tilt_message_received(msg: ReceiveMessage) -> None:
"""Handle tilt updates."""
payload = self._tilt_status_template(msg.payload)
@ -371,7 +378,7 @@ class MqttCover(MqttEntity, CoverEntity):
@callback
@log_messages(self.hass, self.entity_id)
def state_message_received(msg):
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
payload = self._value_template(msg.payload)
@ -409,31 +416,32 @@ class MqttCover(MqttEntity, CoverEntity):
@callback
@log_messages(self.hass, self.entity_id)
def position_message_received(msg):
def position_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT position messages."""
payload = self._get_position_template(msg.payload)
payload: ReceivePayloadType = self._get_position_template(msg.payload)
payload_dict: Any = None
if not payload:
_LOGGER.debug("Ignoring empty position message from '%s'", msg.topic)
return
try:
payload = json_loads(payload)
payload_dict = json_loads(payload)
except JSON_DECODE_EXCEPTIONS:
pass
if isinstance(payload, dict):
if "position" not in payload:
if payload_dict and isinstance(payload_dict, dict):
if "position" not in payload_dict:
_LOGGER.warning(
"Template (position_template) returned JSON without position attribute"
)
return
if "tilt_position" in payload:
if "tilt_position" in payload_dict:
if not self._config.get(CONF_TILT_STATE_OPTIMISTIC):
# reset forced set tilt optimistic
self._tilt_optimistic = False
self.tilt_payload_received(payload["tilt_position"])
payload = payload["position"]
self.tilt_payload_received(payload_dict["tilt_position"])
payload = payload_dict["position"]
try:
percentage_payload = self.find_percentage_in_range(
@ -481,7 +489,7 @@ class MqttCover(MqttEntity, CoverEntity):
self.hass, self._sub_state, topics
)
async def _subscribe_topics(self):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
@ -719,13 +727,15 @@ class MqttCover(MqttEntity, CoverEntity):
else:
await self.async_close_cover_tilt(**kwargs)
def is_tilt_closed(self):
def is_tilt_closed(self) -> bool:
"""Return if the cover is tilted closed."""
return self._tilt_value == self.find_percentage_in_range(
float(self._config[CONF_TILT_CLOSED_POSITION])
)
def find_percentage_in_range(self, position, range_type=TILT_PAYLOAD):
def find_percentage_in_range(
self, position: float, range_type: str = TILT_PAYLOAD
) -> int:
"""Find the 0-100% value within the specified range."""
# the range of motion as defined by the min max values
if range_type == COVER_PAYLOAD:
@ -745,7 +755,9 @@ class MqttCover(MqttEntity, CoverEntity):
return position_percentage
def find_in_range_from_percent(self, percentage, range_type=TILT_PAYLOAD):
def find_in_range_from_percent(
self, percentage: float, range_type: str = TILT_PAYLOAD
) -> int:
"""
Find the adjusted value for 0-100% within the specified range.
@ -768,7 +780,7 @@ class MqttCover(MqttEntity, CoverEntity):
return position
@callback
def tilt_payload_received(self, _payload):
def tilt_payload_received(self, _payload: Any) -> None:
"""Set the tilt value."""
try: