Fix mysensors typing (#51518)

* Fix device

* Fix init

* Fix gateway

* Fix config flow

* Fix helpers

* Remove mysensors from typing ignore list
pull/51521/head
Martin Hjelmare 2021-06-05 13:43:39 +02:00 committed by GitHub
parent 7a6d067eb4
commit e73cdfab2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 80 additions and 76 deletions

View File

@ -42,7 +42,7 @@ from .const import (
DevId,
SensorType,
)
from .device import MySensorsDevice, MySensorsEntity, get_mysensors_devices
from .device import MySensorsDevice, get_mysensors_devices
from .gateway import finish_setup, get_mysensors_gateway, gw_stop, setup_gateway
from .helpers import on_unload
@ -271,7 +271,7 @@ def setup_mysensors_platform(
hass: HomeAssistant,
domain: str, # hass platform name
discovery_info: dict[str, list[DevId]],
device_class: type[MySensorsDevice] | dict[SensorType, type[MySensorsEntity]],
device_class: type[MySensorsDevice] | dict[SensorType, type[MySensorsDevice]],
device_args: (
None | tuple
) = None, # extra arguments that will be given to the entity constructor
@ -302,11 +302,13 @@ def setup_mysensors_platform(
if not gateway:
_LOGGER.warning("Skipping setup of %s, no gateway found", dev_id)
continue
device_class_copy = device_class
if isinstance(device_class, dict):
child = gateway.sensors[node_id].children[child_id]
s_type = gateway.const.Presentation(child.type).name
device_class_copy = device_class[s_type]
else:
device_class_copy = device_class
args_copy = (*device_args, gateway_id, gateway, node_id, child_id, value_type)
devices[dev_id] = device_class_copy(*args_copy)

View File

@ -27,7 +27,7 @@ from homeassistant.components.mysensors import (
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import callback
from homeassistant.data_entry_flow import RESULT_TYPE_FORM, FlowResult
from homeassistant.data_entry_flow import FlowResult
import homeassistant.helpers.config_validation as cv
from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION
@ -111,7 +111,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Set up config flow."""
self._gw_type: str | None = None
async def async_step_import(self, user_input: dict[str, str] | None = None):
async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult:
"""Import a config entry.
This method is called by async_setup and it has already
@ -131,12 +131,14 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
else:
user_input[CONF_GATEWAY_TYPE] = CONF_GATEWAY_TYPE_SERIAL
result: dict[str, Any] = await self.async_step_user(user_input=user_input)
if result["type"] == RESULT_TYPE_FORM:
return self.async_abort(reason=next(iter(result["errors"].values())))
result: FlowResult = await self.async_step_user(user_input=user_input)
if errors := result.get("errors"):
return self.async_abort(reason=next(iter(errors.values())))
return result
async def async_step_user(self, user_input: dict[str, str] | None = None):
async def async_step_user(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Create a config entry from frontend user input."""
schema = {vol.Required(CONF_GATEWAY_TYPE): vol.In(CONF_GATEWAY_TYPE_ALL)}
schema = vol.Schema(schema)
@ -158,9 +160,11 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
return self.async_show_form(step_id="user", data_schema=schema, errors=errors)
async def async_step_gw_serial(self, user_input: dict[str, str] | None = None):
async def async_step_gw_serial(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Create config entry for a serial gateway."""
errors = {}
errors: dict[str, str] = {}
if user_input is not None:
errors.update(
await self.validate_common(CONF_GATEWAY_TYPE_SERIAL, errors, user_input)
@ -187,7 +191,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
step_id="gw_serial", data_schema=schema, errors=errors
)
async def async_step_gw_tcp(self, user_input: dict[str, str] | None = None):
async def async_step_gw_tcp(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Create a config entry for a tcp gateway."""
errors = {}
if user_input is not None:
@ -225,7 +231,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
return True
return False
async def async_step_gw_mqtt(self, user_input: dict[str, str] | None = None):
async def async_step_gw_mqtt(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Create a config entry for a mqtt gateway."""
errors = {}
if user_input is not None:
@ -280,9 +288,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
)
@callback
def _async_create_entry(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
def _async_create_entry(self, user_input: dict[str, Any]) -> FlowResult:
"""Create the config entry."""
return self.async_create_entry(
title=f"{user_input[CONF_DEVICE]}",
@ -296,55 +302,52 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
self,
gw_type: ConfGatewayType,
errors: dict[str, str],
user_input: dict[str, str] | None = None,
user_input: dict[str, Any],
) -> dict[str, str]:
"""Validate parameters common to all gateway types."""
if user_input is not None:
errors.update(_validate_version(user_input.get(CONF_VERSION)))
errors.update(_validate_version(user_input[CONF_VERSION]))
if gw_type != CONF_GATEWAY_TYPE_MQTT:
if gw_type == CONF_GATEWAY_TYPE_TCP:
verification_func = is_socket_address
else:
verification_func = is_serial_port
if gw_type != CONF_GATEWAY_TYPE_MQTT:
if gw_type == CONF_GATEWAY_TYPE_TCP:
verification_func = is_socket_address
else:
verification_func = is_serial_port
try:
await self.hass.async_add_executor_job(
verification_func, user_input.get(CONF_DEVICE)
)
except vol.Invalid:
errors[CONF_DEVICE] = (
"invalid_ip"
if gw_type == CONF_GATEWAY_TYPE_TCP
else "invalid_serial"
)
if CONF_PERSISTENCE_FILE in user_input:
try:
is_persistence_file(user_input[CONF_PERSISTENCE_FILE])
except vol.Invalid:
errors[CONF_PERSISTENCE_FILE] = "invalid_persistence_file"
else:
real_persistence_path = user_input[
CONF_PERSISTENCE_FILE
] = self._normalize_persistence_file(
user_input[CONF_PERSISTENCE_FILE]
)
for other_entry in self._async_current_entries():
if CONF_PERSISTENCE_FILE not in other_entry.data:
continue
if real_persistence_path == self._normalize_persistence_file(
other_entry.data[CONF_PERSISTENCE_FILE]
):
errors[CONF_PERSISTENCE_FILE] = "duplicate_persistence_file"
break
try:
await self.hass.async_add_executor_job(
verification_func, user_input.get(CONF_DEVICE)
)
except vol.Invalid:
errors[CONF_DEVICE] = (
"invalid_ip"
if gw_type == CONF_GATEWAY_TYPE_TCP
else "invalid_serial"
)
if CONF_PERSISTENCE_FILE in user_input:
try:
is_persistence_file(user_input[CONF_PERSISTENCE_FILE])
except vol.Invalid:
errors[CONF_PERSISTENCE_FILE] = "invalid_persistence_file"
else:
real_persistence_path = user_input[
CONF_PERSISTENCE_FILE
] = self._normalize_persistence_file(user_input[CONF_PERSISTENCE_FILE])
for other_entry in self._async_current_entries():
if CONF_PERSISTENCE_FILE not in other_entry.data:
continue
if real_persistence_path == self._normalize_persistence_file(
other_entry.data[CONF_PERSISTENCE_FILE]
):
errors[CONF_PERSISTENCE_FILE] = "duplicate_persistence_file"
break
for other_entry in self._async_current_entries():
if _is_same_device(gw_type, user_input, other_entry):
errors["base"] = "already_configured"
break
for other_entry in self._async_current_entries():
if _is_same_device(gw_type, user_input, other_entry):
errors["base"] = "already_configured"
break
# if no errors so far, try to connect
if not errors and not await try_connect(self.hass, user_input):
errors["base"] = "cannot_connect"
# if no errors so far, try to connect
if not errors and not await try_connect(self.hass, user_input):
errors["base"] = "cannot_connect"
return errors

View File

@ -3,12 +3,13 @@ from __future__ import annotations
from functools import partial
import logging
from typing import Any
from mysensors import BaseAsyncGateway, Sensor
from mysensors.sensor import ChildSensor
from homeassistant.const import ATTR_BATTERY_LEVEL, STATE_OFF, STATE_ON
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo, Entity
@ -36,6 +37,8 @@ MYSENSORS_PLATFORM_DEVICES = "mysensors_devices_{}"
class MySensorsDevice:
"""Representation of a MySensors device."""
hass: HomeAssistant
def __init__(
self,
gateway_id: GatewayId,
@ -51,9 +54,8 @@ class MySensorsDevice:
self.child_id: int = child_id
self.value_type: int = value_type # value_type as int. string variant can be looked up in gateway consts
self.child_type = self._child.type
self._values = {}
self._values: dict[int, Any] = {}
self._update_scheduled = False
self.hass = None
@property
def dev_id(self) -> DevId:

View File

@ -66,7 +66,7 @@ def is_socket_address(value):
raise vol.Invalid("Device is not a valid domain name or ip address") from err
async def try_connect(hass: HomeAssistant, user_input: dict[str, str]) -> bool:
async def try_connect(hass: HomeAssistant, user_input: dict[str, Any]) -> bool:
"""Try to connect to a gateway and report if it worked."""
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
@ -250,7 +250,6 @@ async def _discover_persistent_devices(
hass: HomeAssistant, entry: ConfigEntry, gateway: BaseAsyncGateway
):
"""Discover platforms for devices loaded via persistence file."""
tasks = []
new_devices = defaultdict(list)
for node_id in gateway.sensors:
if not validate_node(gateway, node_id):
@ -263,8 +262,6 @@ async def _discover_persistent_devices(
_LOGGER.debug("discovering persistent devices: %s", new_devices)
for platform, dev_ids in new_devices.items():
discover_mysensors_platform(hass, entry.entry_id, platform, dev_ids)
if tasks:
await asyncio.wait(tasks)
async def gw_stop(hass, entry: ConfigEntry, gateway: BaseAsyncGateway):
@ -331,8 +328,8 @@ def _gw_callback_factory(
msg_type = msg.gateway.const.MessageType(msg.type)
msg_handler: Callable[
[Any, GatewayId, Message], Coroutine[None]
] = HANDLERS.get(msg_type.name)
[HomeAssistant, GatewayId, Message], Coroutine[Any, Any, None]
] | None = HANDLERS.get(msg_type.name)
if msg_handler is None:
return

View File

@ -176,11 +176,15 @@ def validate_child(
) -> defaultdict[str, list[DevId]]:
"""Validate a child. Returns a dict mapping hass platform names to list of DevId."""
validated: defaultdict[str, list[DevId]] = defaultdict(list)
pres: IntEnum = gateway.const.Presentation
set_req: IntEnum = gateway.const.SetReq
pres: type[IntEnum] = gateway.const.Presentation
set_req: type[IntEnum] = gateway.const.SetReq
child_type_name: SensorType | None = next(
(member.name for member in pres if member.value == child.type), None
)
if not child_type_name:
_LOGGER.warning("Child type %s is not supported", child.type)
return validated
value_types: set[int] = {value_type} if value_type else {*child.values}
value_type_names: set[ValueType] = {
member.name for member in set_req if member.value in value_types
@ -199,7 +203,7 @@ def validate_child(
child_value_names: set[ValueType] = {
member.name for member in set_req if member.value in child.values
}
v_names: set[ValueType] = platform_v_names & child_value_names
v_names = platform_v_names & child_value_names
for v_name in v_names:
child_schema_gen = SCHEMAS.get((platform, v_name), default_schema)

View File

@ -1197,9 +1197,6 @@ ignore_errors = true
[mypy-homeassistant.components.mullvad.*]
ignore_errors = true
[mypy-homeassistant.components.mysensors.*]
ignore_errors = true
[mypy-homeassistant.components.neato.*]
ignore_errors = true

View File

@ -127,7 +127,6 @@ IGNORED_MODULES: Final[list[str]] = [
"homeassistant.components.motion_blinds.*",
"homeassistant.components.mqtt.*",
"homeassistant.components.mullvad.*",
"homeassistant.components.mysensors.*",
"homeassistant.components.neato.*",
"homeassistant.components.ness_alarm.*",
"homeassistant.components.nest.*",