Improve MQTT type hints part 2 (#80529)
* Improve typing camera * Improve typing cover * b64 encoding can be either bytes or a string.pull/76999/head
parent
b4ad03784f
commit
bda7e416c4
|
@ -27,6 +27,7 @@ from .mixins import (
|
|||
async_setup_platform_helper,
|
||||
warn_for_legacy_schema,
|
||||
)
|
||||
from .models import ReceiveMessage
|
||||
from .util import valid_subscribe_topic
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -114,8 +115,8 @@ async def _async_setup_entity(
|
|||
hass: HomeAssistant,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
config: ConfigType,
|
||||
config_entry: ConfigEntry | None = None,
|
||||
discovery_data: dict | None = None,
|
||||
config_entry: ConfigEntry,
|
||||
discovery_data: DiscoveryInfoType | None = None,
|
||||
) -> None:
|
||||
"""Set up the MQTT Camera."""
|
||||
async_add_entities([MqttCamera(hass, config, config_entry, discovery_data)])
|
||||
|
@ -124,31 +125,38 @@ async def _async_setup_entity(
|
|||
class MqttCamera(MqttEntity, Camera):
|
||||
"""representation of a MQTT camera."""
|
||||
|
||||
_entity_id_format = camera.ENTITY_ID_FORMAT
|
||||
_attributes_extra_blocked = MQTT_CAMERA_ATTRIBUTES_BLOCKED
|
||||
_entity_id_format: str = camera.ENTITY_ID_FORMAT
|
||||
_attributes_extra_blocked: frozenset[str] = MQTT_CAMERA_ATTRIBUTES_BLOCKED
|
||||
|
||||
def __init__(self, hass, config, config_entry, discovery_data):
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
config_entry: ConfigEntry,
|
||||
discovery_data: DiscoveryInfoType | None,
|
||||
) -> None:
|
||||
"""Initialize the MQTT Camera."""
|
||||
self._last_image = None
|
||||
self._last_image: bytes | None = None
|
||||
|
||||
Camera.__init__(self)
|
||||
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
||||
|
||||
@staticmethod
|
||||
def config_schema():
|
||||
def config_schema() -> vol.Schema:
|
||||
"""Return the config schema."""
|
||||
return DISCOVERY_SCHEMA
|
||||
|
||||
def _prepare_subscribe_topics(self):
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def message_received(msg):
|
||||
def message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
if CONF_IMAGE_ENCODING in self._config:
|
||||
self._last_image = b64decode(msg.payload)
|
||||
else:
|
||||
assert isinstance(msg.payload, bytes)
|
||||
self._last_image = msg.payload
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
|
@ -164,7 +172,7 @@ class MqttCamera(MqttEntity, Camera):
|
|||
},
|
||||
)
|
||||
|
||||
async def _subscribe_topics(self):
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ from homeassistant.core import HomeAssistant, callback
|
|||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_loads
|
||||
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from . import subscription
|
||||
|
@ -50,7 +51,7 @@ from .mixins import (
|
|||
async_setup_platform_helper,
|
||||
warn_for_legacy_schema,
|
||||
)
|
||||
from .models import MqttCommandTemplate, MqttValueTemplate
|
||||
from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage
|
||||
from .util import get_mqtt_data, valid_publish_topic, valid_subscribe_topic
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -113,44 +114,44 @@ MQTT_COVER_ATTRIBUTES_BLOCKED = frozenset(
|
|||
)
|
||||
|
||||
|
||||
def validate_options(value):
|
||||
def validate_options(config: ConfigType) -> ConfigType:
|
||||
"""Validate options.
|
||||
|
||||
If set position topic is set then get position topic is set as well.
|
||||
"""
|
||||
if CONF_SET_POSITION_TOPIC in value and CONF_GET_POSITION_TOPIC not in value:
|
||||
if CONF_SET_POSITION_TOPIC in config and CONF_GET_POSITION_TOPIC not in config:
|
||||
raise vol.Invalid(
|
||||
f"'{CONF_SET_POSITION_TOPIC}' must be set together with '{CONF_GET_POSITION_TOPIC}'."
|
||||
)
|
||||
|
||||
# if templates are set make sure the topic for the template is also set
|
||||
|
||||
if CONF_VALUE_TEMPLATE in value and CONF_STATE_TOPIC not in value:
|
||||
if CONF_VALUE_TEMPLATE in config and CONF_STATE_TOPIC not in config:
|
||||
raise vol.Invalid(
|
||||
f"'{CONF_VALUE_TEMPLATE}' must be set together with '{CONF_STATE_TOPIC}'."
|
||||
)
|
||||
|
||||
if CONF_GET_POSITION_TEMPLATE in value and CONF_GET_POSITION_TOPIC not in value:
|
||||
if CONF_GET_POSITION_TEMPLATE in config and CONF_GET_POSITION_TOPIC not in config:
|
||||
raise vol.Invalid(
|
||||
f"'{CONF_GET_POSITION_TEMPLATE}' must be set together with '{CONF_GET_POSITION_TOPIC}'."
|
||||
)
|
||||
|
||||
if CONF_SET_POSITION_TEMPLATE in value and CONF_SET_POSITION_TOPIC not in value:
|
||||
if CONF_SET_POSITION_TEMPLATE in config and CONF_SET_POSITION_TOPIC not in config:
|
||||
raise vol.Invalid(
|
||||
f"'{CONF_SET_POSITION_TEMPLATE}' must be set together with '{CONF_SET_POSITION_TOPIC}'."
|
||||
)
|
||||
|
||||
if CONF_TILT_COMMAND_TEMPLATE in value and CONF_TILT_COMMAND_TOPIC not in value:
|
||||
if CONF_TILT_COMMAND_TEMPLATE in config and CONF_TILT_COMMAND_TOPIC not in config:
|
||||
raise vol.Invalid(
|
||||
f"'{CONF_TILT_COMMAND_TEMPLATE}' must be set together with '{CONF_TILT_COMMAND_TOPIC}'."
|
||||
)
|
||||
|
||||
if CONF_TILT_STATUS_TEMPLATE in value and CONF_TILT_STATUS_TOPIC not in value:
|
||||
if CONF_TILT_STATUS_TEMPLATE in config and CONF_TILT_STATUS_TOPIC not in config:
|
||||
raise vol.Invalid(
|
||||
f"'{CONF_TILT_STATUS_TEMPLATE}' must be set together with '{CONF_TILT_STATUS_TOPIC}'."
|
||||
)
|
||||
|
||||
return value
|
||||
return config
|
||||
|
||||
|
||||
_PLATFORM_SCHEMA_BASE = MQTT_BASE_SCHEMA.extend(
|
||||
|
@ -251,8 +252,8 @@ async def _async_setup_entity(
|
|||
hass: HomeAssistant,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
config: ConfigType,
|
||||
config_entry: ConfigEntry | None = None,
|
||||
discovery_data: dict | None = None,
|
||||
config_entry: ConfigEntry,
|
||||
discovery_data: DiscoveryInfoType | None = None,
|
||||
) -> None:
|
||||
"""Set up the MQTT Cover."""
|
||||
async_add_entities([MqttCover(hass, config, config_entry, discovery_data)])
|
||||
|
@ -261,26 +262,32 @@ async def _async_setup_entity(
|
|||
class MqttCover(MqttEntity, CoverEntity):
|
||||
"""Representation of a cover that can be controlled using MQTT."""
|
||||
|
||||
_entity_id_format = cover.ENTITY_ID_FORMAT
|
||||
_attributes_extra_blocked = MQTT_COVER_ATTRIBUTES_BLOCKED
|
||||
_entity_id_format: str = cover.ENTITY_ID_FORMAT
|
||||
_attributes_extra_blocked: frozenset[str] = MQTT_COVER_ATTRIBUTES_BLOCKED
|
||||
|
||||
def __init__(self, hass, config, config_entry, discovery_data):
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
config_entry: ConfigEntry,
|
||||
discovery_data: DiscoveryInfoType | None,
|
||||
) -> None:
|
||||
"""Initialize the cover."""
|
||||
self._position = None
|
||||
self._state = None
|
||||
self._position: int | None = None
|
||||
self._state: str | None = None
|
||||
|
||||
self._optimistic = None
|
||||
self._tilt_value = None
|
||||
self._tilt_optimistic = None
|
||||
self._optimistic: bool | None = None
|
||||
self._tilt_value: int | None = None
|
||||
self._tilt_optimistic: bool | None = None
|
||||
|
||||
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
||||
|
||||
@staticmethod
|
||||
def config_schema():
|
||||
def config_schema() -> vol.Schema:
|
||||
"""Return the config schema."""
|
||||
return DISCOVERY_SCHEMA
|
||||
|
||||
def _setup_from_config(self, config):
|
||||
def _setup_from_config(self, config: ConfigType) -> None:
|
||||
no_position = (
|
||||
config.get(CONF_SET_POSITION_TOPIC) is None
|
||||
and config.get(CONF_GET_POSITION_TOPIC) is None
|
||||
|
@ -353,13 +360,13 @@ class MqttCover(MqttEntity, CoverEntity):
|
|||
config_attributes=template_config_attributes,
|
||||
).async_render_with_possible_json_value
|
||||
|
||||
def _prepare_subscribe_topics(self):
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
topics = {}
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def tilt_message_received(msg):
|
||||
def tilt_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle tilt updates."""
|
||||
payload = self._tilt_status_template(msg.payload)
|
||||
|
||||
|
@ -371,7 +378,7 @@ class MqttCover(MqttEntity, CoverEntity):
|
|||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def state_message_received(msg):
|
||||
def state_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT state messages."""
|
||||
payload = self._value_template(msg.payload)
|
||||
|
||||
|
@ -409,31 +416,32 @@ class MqttCover(MqttEntity, CoverEntity):
|
|||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def position_message_received(msg):
|
||||
def position_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT position messages."""
|
||||
payload = self._get_position_template(msg.payload)
|
||||
payload: ReceivePayloadType = self._get_position_template(msg.payload)
|
||||
payload_dict: Any = None
|
||||
|
||||
if not payload:
|
||||
_LOGGER.debug("Ignoring empty position message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
try:
|
||||
payload = json_loads(payload)
|
||||
payload_dict = json_loads(payload)
|
||||
except JSON_DECODE_EXCEPTIONS:
|
||||
pass
|
||||
|
||||
if isinstance(payload, dict):
|
||||
if "position" not in payload:
|
||||
if payload_dict and isinstance(payload_dict, dict):
|
||||
if "position" not in payload_dict:
|
||||
_LOGGER.warning(
|
||||
"Template (position_template) returned JSON without position attribute"
|
||||
)
|
||||
return
|
||||
if "tilt_position" in payload:
|
||||
if "tilt_position" in payload_dict:
|
||||
if not self._config.get(CONF_TILT_STATE_OPTIMISTIC):
|
||||
# reset forced set tilt optimistic
|
||||
self._tilt_optimistic = False
|
||||
self.tilt_payload_received(payload["tilt_position"])
|
||||
payload = payload["position"]
|
||||
self.tilt_payload_received(payload_dict["tilt_position"])
|
||||
payload = payload_dict["position"]
|
||||
|
||||
try:
|
||||
percentage_payload = self.find_percentage_in_range(
|
||||
|
@ -481,7 +489,7 @@ class MqttCover(MqttEntity, CoverEntity):
|
|||
self.hass, self._sub_state, topics
|
||||
)
|
||||
|
||||
async def _subscribe_topics(self):
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
|
||||
|
@ -719,13 +727,15 @@ class MqttCover(MqttEntity, CoverEntity):
|
|||
else:
|
||||
await self.async_close_cover_tilt(**kwargs)
|
||||
|
||||
def is_tilt_closed(self):
|
||||
def is_tilt_closed(self) -> bool:
|
||||
"""Return if the cover is tilted closed."""
|
||||
return self._tilt_value == self.find_percentage_in_range(
|
||||
float(self._config[CONF_TILT_CLOSED_POSITION])
|
||||
)
|
||||
|
||||
def find_percentage_in_range(self, position, range_type=TILT_PAYLOAD):
|
||||
def find_percentage_in_range(
|
||||
self, position: float, range_type: str = TILT_PAYLOAD
|
||||
) -> int:
|
||||
"""Find the 0-100% value within the specified range."""
|
||||
# the range of motion as defined by the min max values
|
||||
if range_type == COVER_PAYLOAD:
|
||||
|
@ -745,7 +755,9 @@ class MqttCover(MqttEntity, CoverEntity):
|
|||
|
||||
return position_percentage
|
||||
|
||||
def find_in_range_from_percent(self, percentage, range_type=TILT_PAYLOAD):
|
||||
def find_in_range_from_percent(
|
||||
self, percentage: float, range_type: str = TILT_PAYLOAD
|
||||
) -> int:
|
||||
"""
|
||||
Find the adjusted value for 0-100% within the specified range.
|
||||
|
||||
|
@ -768,7 +780,7 @@ class MqttCover(MqttEntity, CoverEntity):
|
|||
return position
|
||||
|
||||
@callback
|
||||
def tilt_payload_received(self, _payload):
|
||||
def tilt_payload_received(self, _payload: Any) -> None:
|
||||
"""Set the tilt value."""
|
||||
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue