diff --git a/homeassistant/components/mysensors/__init__.py b/homeassistant/components/mysensors/__init__.py index 9d23cfd24b6..3f36d6e96cc 100644 --- a/homeassistant/components/mysensors/__init__.py +++ b/homeassistant/components/mysensors/__init__.py @@ -42,7 +42,7 @@ from .const import ( DevId, SensorType, ) -from .device import MySensorsDevice, MySensorsEntity, get_mysensors_devices +from .device import MySensorsDevice, get_mysensors_devices from .gateway import finish_setup, get_mysensors_gateway, gw_stop, setup_gateway from .helpers import on_unload @@ -271,7 +271,7 @@ def setup_mysensors_platform( hass: HomeAssistant, domain: str, # hass platform name discovery_info: dict[str, list[DevId]], - device_class: type[MySensorsDevice] | dict[SensorType, type[MySensorsEntity]], + device_class: type[MySensorsDevice] | dict[SensorType, type[MySensorsDevice]], device_args: ( None | tuple ) = None, # extra arguments that will be given to the entity constructor @@ -302,11 +302,13 @@ def setup_mysensors_platform( if not gateway: _LOGGER.warning("Skipping setup of %s, no gateway found", dev_id) continue - device_class_copy = device_class + if isinstance(device_class, dict): child = gateway.sensors[node_id].children[child_id] s_type = gateway.const.Presentation(child.type).name device_class_copy = device_class[s_type] + else: + device_class_copy = device_class args_copy = (*device_args, gateway_id, gateway, node_id, child_id, value_type) devices[dev_id] = device_class_copy(*args_copy) diff --git a/homeassistant/components/mysensors/config_flow.py b/homeassistant/components/mysensors/config_flow.py index ad260c3ab58..223d27a2a60 100644 --- a/homeassistant/components/mysensors/config_flow.py +++ b/homeassistant/components/mysensors/config_flow.py @@ -27,7 +27,7 @@ from homeassistant.components.mysensors import ( ) from homeassistant.config_entries import ConfigEntry from homeassistant.core import callback -from homeassistant.data_entry_flow import RESULT_TYPE_FORM, FlowResult +from homeassistant.data_entry_flow import FlowResult import homeassistant.helpers.config_validation as cv from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION @@ -111,7 +111,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Set up config flow.""" self._gw_type: str | None = None - async def async_step_import(self, user_input: dict[str, str] | None = None): + async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult: """Import a config entry. This method is called by async_setup and it has already @@ -131,12 +131,14 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): else: user_input[CONF_GATEWAY_TYPE] = CONF_GATEWAY_TYPE_SERIAL - result: dict[str, Any] = await self.async_step_user(user_input=user_input) - if result["type"] == RESULT_TYPE_FORM: - return self.async_abort(reason=next(iter(result["errors"].values()))) + result: FlowResult = await self.async_step_user(user_input=user_input) + if errors := result.get("errors"): + return self.async_abort(reason=next(iter(errors.values()))) return result - async def async_step_user(self, user_input: dict[str, str] | None = None): + async def async_step_user( + self, user_input: dict[str, str] | None = None + ) -> FlowResult: """Create a config entry from frontend user input.""" schema = {vol.Required(CONF_GATEWAY_TYPE): vol.In(CONF_GATEWAY_TYPE_ALL)} schema = vol.Schema(schema) @@ -158,9 +160,11 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): return self.async_show_form(step_id="user", data_schema=schema, errors=errors) - async def async_step_gw_serial(self, user_input: dict[str, str] | None = None): + async def async_step_gw_serial( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: """Create config entry for a serial gateway.""" - errors = {} + errors: dict[str, str] = {} if user_input is not None: errors.update( await self.validate_common(CONF_GATEWAY_TYPE_SERIAL, errors, user_input) @@ -187,7 +191,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): step_id="gw_serial", data_schema=schema, errors=errors ) - async def async_step_gw_tcp(self, user_input: dict[str, str] | None = None): + async def async_step_gw_tcp( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: """Create a config entry for a tcp gateway.""" errors = {} if user_input is not None: @@ -225,7 +231,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): return True return False - async def async_step_gw_mqtt(self, user_input: dict[str, str] | None = None): + async def async_step_gw_mqtt( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: """Create a config entry for a mqtt gateway.""" errors = {} if user_input is not None: @@ -280,9 +288,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): ) @callback - def _async_create_entry( - self, user_input: dict[str, str] | None = None - ) -> FlowResult: + def _async_create_entry(self, user_input: dict[str, Any]) -> FlowResult: """Create the config entry.""" return self.async_create_entry( title=f"{user_input[CONF_DEVICE]}", @@ -296,55 +302,52 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): self, gw_type: ConfGatewayType, errors: dict[str, str], - user_input: dict[str, str] | None = None, + user_input: dict[str, Any], ) -> dict[str, str]: """Validate parameters common to all gateway types.""" - if user_input is not None: - errors.update(_validate_version(user_input.get(CONF_VERSION))) + errors.update(_validate_version(user_input[CONF_VERSION])) - if gw_type != CONF_GATEWAY_TYPE_MQTT: - if gw_type == CONF_GATEWAY_TYPE_TCP: - verification_func = is_socket_address - else: - verification_func = is_serial_port + if gw_type != CONF_GATEWAY_TYPE_MQTT: + if gw_type == CONF_GATEWAY_TYPE_TCP: + verification_func = is_socket_address + else: + verification_func = is_serial_port - try: - await self.hass.async_add_executor_job( - verification_func, user_input.get(CONF_DEVICE) - ) - except vol.Invalid: - errors[CONF_DEVICE] = ( - "invalid_ip" - if gw_type == CONF_GATEWAY_TYPE_TCP - else "invalid_serial" - ) - if CONF_PERSISTENCE_FILE in user_input: - try: - is_persistence_file(user_input[CONF_PERSISTENCE_FILE]) - except vol.Invalid: - errors[CONF_PERSISTENCE_FILE] = "invalid_persistence_file" - else: - real_persistence_path = user_input[ - CONF_PERSISTENCE_FILE - ] = self._normalize_persistence_file( - user_input[CONF_PERSISTENCE_FILE] - ) - for other_entry in self._async_current_entries(): - if CONF_PERSISTENCE_FILE not in other_entry.data: - continue - if real_persistence_path == self._normalize_persistence_file( - other_entry.data[CONF_PERSISTENCE_FILE] - ): - errors[CONF_PERSISTENCE_FILE] = "duplicate_persistence_file" - break + try: + await self.hass.async_add_executor_job( + verification_func, user_input.get(CONF_DEVICE) + ) + except vol.Invalid: + errors[CONF_DEVICE] = ( + "invalid_ip" + if gw_type == CONF_GATEWAY_TYPE_TCP + else "invalid_serial" + ) + if CONF_PERSISTENCE_FILE in user_input: + try: + is_persistence_file(user_input[CONF_PERSISTENCE_FILE]) + except vol.Invalid: + errors[CONF_PERSISTENCE_FILE] = "invalid_persistence_file" + else: + real_persistence_path = user_input[ + CONF_PERSISTENCE_FILE + ] = self._normalize_persistence_file(user_input[CONF_PERSISTENCE_FILE]) + for other_entry in self._async_current_entries(): + if CONF_PERSISTENCE_FILE not in other_entry.data: + continue + if real_persistence_path == self._normalize_persistence_file( + other_entry.data[CONF_PERSISTENCE_FILE] + ): + errors[CONF_PERSISTENCE_FILE] = "duplicate_persistence_file" + break - for other_entry in self._async_current_entries(): - if _is_same_device(gw_type, user_input, other_entry): - errors["base"] = "already_configured" - break + for other_entry in self._async_current_entries(): + if _is_same_device(gw_type, user_input, other_entry): + errors["base"] = "already_configured" + break - # if no errors so far, try to connect - if not errors and not await try_connect(self.hass, user_input): - errors["base"] = "cannot_connect" + # if no errors so far, try to connect + if not errors and not await try_connect(self.hass, user_input): + errors["base"] = "cannot_connect" return errors diff --git a/homeassistant/components/mysensors/device.py b/homeassistant/components/mysensors/device.py index c1d8c431bc0..c066e633eaa 100644 --- a/homeassistant/components/mysensors/device.py +++ b/homeassistant/components/mysensors/device.py @@ -3,12 +3,13 @@ from __future__ import annotations from functools import partial import logging +from typing import Any from mysensors import BaseAsyncGateway, Sensor from mysensors.sensor import ChildSensor from homeassistant.const import ATTR_BATTERY_LEVEL, STATE_OFF, STATE_ON -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity import DeviceInfo, Entity @@ -36,6 +37,8 @@ MYSENSORS_PLATFORM_DEVICES = "mysensors_devices_{}" class MySensorsDevice: """Representation of a MySensors device.""" + hass: HomeAssistant + def __init__( self, gateway_id: GatewayId, @@ -51,9 +54,8 @@ class MySensorsDevice: self.child_id: int = child_id self.value_type: int = value_type # value_type as int. string variant can be looked up in gateway consts self.child_type = self._child.type - self._values = {} + self._values: dict[int, Any] = {} self._update_scheduled = False - self.hass = None @property def dev_id(self) -> DevId: diff --git a/homeassistant/components/mysensors/gateway.py b/homeassistant/components/mysensors/gateway.py index ec403e6e34b..c0a91fbdb08 100644 --- a/homeassistant/components/mysensors/gateway.py +++ b/homeassistant/components/mysensors/gateway.py @@ -66,7 +66,7 @@ def is_socket_address(value): raise vol.Invalid("Device is not a valid domain name or ip address") from err -async def try_connect(hass: HomeAssistant, user_input: dict[str, str]) -> bool: +async def try_connect(hass: HomeAssistant, user_input: dict[str, Any]) -> bool: """Try to connect to a gateway and report if it worked.""" if user_input[CONF_DEVICE] == MQTT_COMPONENT: return True # dont validate mqtt. mqtt gateways dont send ready messages :( @@ -250,7 +250,6 @@ async def _discover_persistent_devices( hass: HomeAssistant, entry: ConfigEntry, gateway: BaseAsyncGateway ): """Discover platforms for devices loaded via persistence file.""" - tasks = [] new_devices = defaultdict(list) for node_id in gateway.sensors: if not validate_node(gateway, node_id): @@ -263,8 +262,6 @@ async def _discover_persistent_devices( _LOGGER.debug("discovering persistent devices: %s", new_devices) for platform, dev_ids in new_devices.items(): discover_mysensors_platform(hass, entry.entry_id, platform, dev_ids) - if tasks: - await asyncio.wait(tasks) async def gw_stop(hass, entry: ConfigEntry, gateway: BaseAsyncGateway): @@ -331,8 +328,8 @@ def _gw_callback_factory( msg_type = msg.gateway.const.MessageType(msg.type) msg_handler: Callable[ - [Any, GatewayId, Message], Coroutine[None] - ] = HANDLERS.get(msg_type.name) + [HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None] + ] | None = HANDLERS.get(msg_type.name) if msg_handler is None: return diff --git a/homeassistant/components/mysensors/helpers.py b/homeassistant/components/mysensors/helpers.py index 9a35f67d49b..54f173de3e3 100644 --- a/homeassistant/components/mysensors/helpers.py +++ b/homeassistant/components/mysensors/helpers.py @@ -176,11 +176,15 @@ def validate_child( ) -> defaultdict[str, list[DevId]]: """Validate a child. Returns a dict mapping hass platform names to list of DevId.""" validated: defaultdict[str, list[DevId]] = defaultdict(list) - pres: IntEnum = gateway.const.Presentation - set_req: IntEnum = gateway.const.SetReq + pres: type[IntEnum] = gateway.const.Presentation + set_req: type[IntEnum] = gateway.const.SetReq child_type_name: SensorType | None = next( (member.name for member in pres if member.value == child.type), None ) + if not child_type_name: + _LOGGER.warning("Child type %s is not supported", child.type) + return validated + value_types: set[int] = {value_type} if value_type else {*child.values} value_type_names: set[ValueType] = { member.name for member in set_req if member.value in value_types @@ -199,7 +203,7 @@ def validate_child( child_value_names: set[ValueType] = { member.name for member in set_req if member.value in child.values } - v_names: set[ValueType] = platform_v_names & child_value_names + v_names = platform_v_names & child_value_names for v_name in v_names: child_schema_gen = SCHEMAS.get((platform, v_name), default_schema) diff --git a/mypy.ini b/mypy.ini index 43468d5b173..7c2dbd38ccd 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1197,9 +1197,6 @@ ignore_errors = true [mypy-homeassistant.components.mullvad.*] ignore_errors = true -[mypy-homeassistant.components.mysensors.*] -ignore_errors = true - [mypy-homeassistant.components.neato.*] ignore_errors = true diff --git a/script/hassfest/mypy_config.py b/script/hassfest/mypy_config.py index 6310d0117c5..e9567be8924 100644 --- a/script/hassfest/mypy_config.py +++ b/script/hassfest/mypy_config.py @@ -127,7 +127,6 @@ IGNORED_MODULES: Final[list[str]] = [ "homeassistant.components.motion_blinds.*", "homeassistant.components.mqtt.*", "homeassistant.components.mullvad.*", - "homeassistant.components.mysensors.*", "homeassistant.components.neato.*", "homeassistant.components.ness_alarm.*", "homeassistant.components.nest.*",