Fix mysensors typing (#51518)
* Fix device * Fix init * Fix gateway * Fix config flow * Fix helpers * Remove mysensors from typing ignore listpull/51521/head
parent
7a6d067eb4
commit
e73cdfab2f
|
@ -42,7 +42,7 @@ from .const import (
|
|||
DevId,
|
||||
SensorType,
|
||||
)
|
||||
from .device import MySensorsDevice, MySensorsEntity, get_mysensors_devices
|
||||
from .device import MySensorsDevice, get_mysensors_devices
|
||||
from .gateway import finish_setup, get_mysensors_gateway, gw_stop, setup_gateway
|
||||
from .helpers import on_unload
|
||||
|
||||
|
@ -271,7 +271,7 @@ def setup_mysensors_platform(
|
|||
hass: HomeAssistant,
|
||||
domain: str, # hass platform name
|
||||
discovery_info: dict[str, list[DevId]],
|
||||
device_class: type[MySensorsDevice] | dict[SensorType, type[MySensorsEntity]],
|
||||
device_class: type[MySensorsDevice] | dict[SensorType, type[MySensorsDevice]],
|
||||
device_args: (
|
||||
None | tuple
|
||||
) = None, # extra arguments that will be given to the entity constructor
|
||||
|
@ -302,11 +302,13 @@ def setup_mysensors_platform(
|
|||
if not gateway:
|
||||
_LOGGER.warning("Skipping setup of %s, no gateway found", dev_id)
|
||||
continue
|
||||
device_class_copy = device_class
|
||||
|
||||
if isinstance(device_class, dict):
|
||||
child = gateway.sensors[node_id].children[child_id]
|
||||
s_type = gateway.const.Presentation(child.type).name
|
||||
device_class_copy = device_class[s_type]
|
||||
else:
|
||||
device_class_copy = device_class
|
||||
|
||||
args_copy = (*device_args, gateway_id, gateway, node_id, child_id, value_type)
|
||||
devices[dev_id] = device_class_copy(*args_copy)
|
||||
|
|
|
@ -27,7 +27,7 @@ from homeassistant.components.mysensors import (
|
|||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.data_entry_flow import RESULT_TYPE_FORM, FlowResult
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION
|
||||
|
@ -111,7 +111,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
"""Set up config flow."""
|
||||
self._gw_type: str | None = None
|
||||
|
||||
async def async_step_import(self, user_input: dict[str, str] | None = None):
|
||||
async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult:
|
||||
"""Import a config entry.
|
||||
|
||||
This method is called by async_setup and it has already
|
||||
|
@ -131,12 +131,14 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
else:
|
||||
user_input[CONF_GATEWAY_TYPE] = CONF_GATEWAY_TYPE_SERIAL
|
||||
|
||||
result: dict[str, Any] = await self.async_step_user(user_input=user_input)
|
||||
if result["type"] == RESULT_TYPE_FORM:
|
||||
return self.async_abort(reason=next(iter(result["errors"].values())))
|
||||
result: FlowResult = await self.async_step_user(user_input=user_input)
|
||||
if errors := result.get("errors"):
|
||||
return self.async_abort(reason=next(iter(errors.values())))
|
||||
return result
|
||||
|
||||
async def async_step_user(self, user_input: dict[str, str] | None = None):
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
"""Create a config entry from frontend user input."""
|
||||
schema = {vol.Required(CONF_GATEWAY_TYPE): vol.In(CONF_GATEWAY_TYPE_ALL)}
|
||||
schema = vol.Schema(schema)
|
||||
|
@ -158,9 +160,11 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
|
||||
return self.async_show_form(step_id="user", data_schema=schema, errors=errors)
|
||||
|
||||
async def async_step_gw_serial(self, user_input: dict[str, str] | None = None):
|
||||
async def async_step_gw_serial(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Create config entry for a serial gateway."""
|
||||
errors = {}
|
||||
errors: dict[str, str] = {}
|
||||
if user_input is not None:
|
||||
errors.update(
|
||||
await self.validate_common(CONF_GATEWAY_TYPE_SERIAL, errors, user_input)
|
||||
|
@ -187,7 +191,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
step_id="gw_serial", data_schema=schema, errors=errors
|
||||
)
|
||||
|
||||
async def async_step_gw_tcp(self, user_input: dict[str, str] | None = None):
|
||||
async def async_step_gw_tcp(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Create a config entry for a tcp gateway."""
|
||||
errors = {}
|
||||
if user_input is not None:
|
||||
|
@ -225,7 +231,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
return True
|
||||
return False
|
||||
|
||||
async def async_step_gw_mqtt(self, user_input: dict[str, str] | None = None):
|
||||
async def async_step_gw_mqtt(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Create a config entry for a mqtt gateway."""
|
||||
errors = {}
|
||||
if user_input is not None:
|
||||
|
@ -280,9 +288,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
)
|
||||
|
||||
@callback
|
||||
def _async_create_entry(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> FlowResult:
|
||||
def _async_create_entry(self, user_input: dict[str, Any]) -> FlowResult:
|
||||
"""Create the config entry."""
|
||||
return self.async_create_entry(
|
||||
title=f"{user_input[CONF_DEVICE]}",
|
||||
|
@ -296,55 +302,52 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
self,
|
||||
gw_type: ConfGatewayType,
|
||||
errors: dict[str, str],
|
||||
user_input: dict[str, str] | None = None,
|
||||
user_input: dict[str, Any],
|
||||
) -> dict[str, str]:
|
||||
"""Validate parameters common to all gateway types."""
|
||||
if user_input is not None:
|
||||
errors.update(_validate_version(user_input.get(CONF_VERSION)))
|
||||
errors.update(_validate_version(user_input[CONF_VERSION]))
|
||||
|
||||
if gw_type != CONF_GATEWAY_TYPE_MQTT:
|
||||
if gw_type == CONF_GATEWAY_TYPE_TCP:
|
||||
verification_func = is_socket_address
|
||||
else:
|
||||
verification_func = is_serial_port
|
||||
if gw_type != CONF_GATEWAY_TYPE_MQTT:
|
||||
if gw_type == CONF_GATEWAY_TYPE_TCP:
|
||||
verification_func = is_socket_address
|
||||
else:
|
||||
verification_func = is_serial_port
|
||||
|
||||
try:
|
||||
await self.hass.async_add_executor_job(
|
||||
verification_func, user_input.get(CONF_DEVICE)
|
||||
)
|
||||
except vol.Invalid:
|
||||
errors[CONF_DEVICE] = (
|
||||
"invalid_ip"
|
||||
if gw_type == CONF_GATEWAY_TYPE_TCP
|
||||
else "invalid_serial"
|
||||
)
|
||||
if CONF_PERSISTENCE_FILE in user_input:
|
||||
try:
|
||||
is_persistence_file(user_input[CONF_PERSISTENCE_FILE])
|
||||
except vol.Invalid:
|
||||
errors[CONF_PERSISTENCE_FILE] = "invalid_persistence_file"
|
||||
else:
|
||||
real_persistence_path = user_input[
|
||||
CONF_PERSISTENCE_FILE
|
||||
] = self._normalize_persistence_file(
|
||||
user_input[CONF_PERSISTENCE_FILE]
|
||||
)
|
||||
for other_entry in self._async_current_entries():
|
||||
if CONF_PERSISTENCE_FILE not in other_entry.data:
|
||||
continue
|
||||
if real_persistence_path == self._normalize_persistence_file(
|
||||
other_entry.data[CONF_PERSISTENCE_FILE]
|
||||
):
|
||||
errors[CONF_PERSISTENCE_FILE] = "duplicate_persistence_file"
|
||||
break
|
||||
try:
|
||||
await self.hass.async_add_executor_job(
|
||||
verification_func, user_input.get(CONF_DEVICE)
|
||||
)
|
||||
except vol.Invalid:
|
||||
errors[CONF_DEVICE] = (
|
||||
"invalid_ip"
|
||||
if gw_type == CONF_GATEWAY_TYPE_TCP
|
||||
else "invalid_serial"
|
||||
)
|
||||
if CONF_PERSISTENCE_FILE in user_input:
|
||||
try:
|
||||
is_persistence_file(user_input[CONF_PERSISTENCE_FILE])
|
||||
except vol.Invalid:
|
||||
errors[CONF_PERSISTENCE_FILE] = "invalid_persistence_file"
|
||||
else:
|
||||
real_persistence_path = user_input[
|
||||
CONF_PERSISTENCE_FILE
|
||||
] = self._normalize_persistence_file(user_input[CONF_PERSISTENCE_FILE])
|
||||
for other_entry in self._async_current_entries():
|
||||
if CONF_PERSISTENCE_FILE not in other_entry.data:
|
||||
continue
|
||||
if real_persistence_path == self._normalize_persistence_file(
|
||||
other_entry.data[CONF_PERSISTENCE_FILE]
|
||||
):
|
||||
errors[CONF_PERSISTENCE_FILE] = "duplicate_persistence_file"
|
||||
break
|
||||
|
||||
for other_entry in self._async_current_entries():
|
||||
if _is_same_device(gw_type, user_input, other_entry):
|
||||
errors["base"] = "already_configured"
|
||||
break
|
||||
for other_entry in self._async_current_entries():
|
||||
if _is_same_device(gw_type, user_input, other_entry):
|
||||
errors["base"] = "already_configured"
|
||||
break
|
||||
|
||||
# if no errors so far, try to connect
|
||||
if not errors and not await try_connect(self.hass, user_input):
|
||||
errors["base"] = "cannot_connect"
|
||||
# if no errors so far, try to connect
|
||||
if not errors and not await try_connect(self.hass, user_input):
|
||||
errors["base"] = "cannot_connect"
|
||||
|
||||
return errors
|
||||
|
|
|
@ -3,12 +3,13 @@ from __future__ import annotations
|
|||
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from mysensors import BaseAsyncGateway, Sensor
|
||||
from mysensors.sensor import ChildSensor
|
||||
|
||||
from homeassistant.const import ATTR_BATTERY_LEVEL, STATE_OFF, STATE_ON
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.entity import DeviceInfo, Entity
|
||||
|
||||
|
@ -36,6 +37,8 @@ MYSENSORS_PLATFORM_DEVICES = "mysensors_devices_{}"
|
|||
class MySensorsDevice:
|
||||
"""Representation of a MySensors device."""
|
||||
|
||||
hass: HomeAssistant
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gateway_id: GatewayId,
|
||||
|
@ -51,9 +54,8 @@ class MySensorsDevice:
|
|||
self.child_id: int = child_id
|
||||
self.value_type: int = value_type # value_type as int. string variant can be looked up in gateway consts
|
||||
self.child_type = self._child.type
|
||||
self._values = {}
|
||||
self._values: dict[int, Any] = {}
|
||||
self._update_scheduled = False
|
||||
self.hass = None
|
||||
|
||||
@property
|
||||
def dev_id(self) -> DevId:
|
||||
|
|
|
@ -66,7 +66,7 @@ def is_socket_address(value):
|
|||
raise vol.Invalid("Device is not a valid domain name or ip address") from err
|
||||
|
||||
|
||||
async def try_connect(hass: HomeAssistant, user_input: dict[str, str]) -> bool:
|
||||
async def try_connect(hass: HomeAssistant, user_input: dict[str, Any]) -> bool:
|
||||
"""Try to connect to a gateway and report if it worked."""
|
||||
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
|
||||
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
|
||||
|
@ -250,7 +250,6 @@ async def _discover_persistent_devices(
|
|||
hass: HomeAssistant, entry: ConfigEntry, gateway: BaseAsyncGateway
|
||||
):
|
||||
"""Discover platforms for devices loaded via persistence file."""
|
||||
tasks = []
|
||||
new_devices = defaultdict(list)
|
||||
for node_id in gateway.sensors:
|
||||
if not validate_node(gateway, node_id):
|
||||
|
@ -263,8 +262,6 @@ async def _discover_persistent_devices(
|
|||
_LOGGER.debug("discovering persistent devices: %s", new_devices)
|
||||
for platform, dev_ids in new_devices.items():
|
||||
discover_mysensors_platform(hass, entry.entry_id, platform, dev_ids)
|
||||
if tasks:
|
||||
await asyncio.wait(tasks)
|
||||
|
||||
|
||||
async def gw_stop(hass, entry: ConfigEntry, gateway: BaseAsyncGateway):
|
||||
|
@ -331,8 +328,8 @@ def _gw_callback_factory(
|
|||
|
||||
msg_type = msg.gateway.const.MessageType(msg.type)
|
||||
msg_handler: Callable[
|
||||
[Any, GatewayId, Message], Coroutine[None]
|
||||
] = HANDLERS.get(msg_type.name)
|
||||
[HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None]
|
||||
] | None = HANDLERS.get(msg_type.name)
|
||||
|
||||
if msg_handler is None:
|
||||
return
|
||||
|
|
|
@ -176,11 +176,15 @@ def validate_child(
|
|||
) -> defaultdict[str, list[DevId]]:
|
||||
"""Validate a child. Returns a dict mapping hass platform names to list of DevId."""
|
||||
validated: defaultdict[str, list[DevId]] = defaultdict(list)
|
||||
pres: IntEnum = gateway.const.Presentation
|
||||
set_req: IntEnum = gateway.const.SetReq
|
||||
pres: type[IntEnum] = gateway.const.Presentation
|
||||
set_req: type[IntEnum] = gateway.const.SetReq
|
||||
child_type_name: SensorType | None = next(
|
||||
(member.name for member in pres if member.value == child.type), None
|
||||
)
|
||||
if not child_type_name:
|
||||
_LOGGER.warning("Child type %s is not supported", child.type)
|
||||
return validated
|
||||
|
||||
value_types: set[int] = {value_type} if value_type else {*child.values}
|
||||
value_type_names: set[ValueType] = {
|
||||
member.name for member in set_req if member.value in value_types
|
||||
|
@ -199,7 +203,7 @@ def validate_child(
|
|||
child_value_names: set[ValueType] = {
|
||||
member.name for member in set_req if member.value in child.values
|
||||
}
|
||||
v_names: set[ValueType] = platform_v_names & child_value_names
|
||||
v_names = platform_v_names & child_value_names
|
||||
|
||||
for v_name in v_names:
|
||||
child_schema_gen = SCHEMAS.get((platform, v_name), default_schema)
|
||||
|
|
3
mypy.ini
3
mypy.ini
|
@ -1197,9 +1197,6 @@ ignore_errors = true
|
|||
[mypy-homeassistant.components.mullvad.*]
|
||||
ignore_errors = true
|
||||
|
||||
[mypy-homeassistant.components.mysensors.*]
|
||||
ignore_errors = true
|
||||
|
||||
[mypy-homeassistant.components.neato.*]
|
||||
ignore_errors = true
|
||||
|
||||
|
|
|
@ -127,7 +127,6 @@ IGNORED_MODULES: Final[list[str]] = [
|
|||
"homeassistant.components.motion_blinds.*",
|
||||
"homeassistant.components.mqtt.*",
|
||||
"homeassistant.components.mullvad.*",
|
||||
"homeassistant.components.mysensors.*",
|
||||
"homeassistant.components.neato.*",
|
||||
"homeassistant.components.ness_alarm.*",
|
||||
"homeassistant.components.nest.*",
|
||||
|
|
Loading…
Reference in New Issue