"""Validate Modbus configuration.""" from __future__ import annotations from collections import namedtuple import logging import struct from typing import Any import voluptuous as vol from homeassistant.const import ( CONF_ADDRESS, CONF_COMMAND_OFF, CONF_COMMAND_ON, CONF_COUNT, CONF_HOST, CONF_NAME, CONF_PORT, CONF_SCAN_INTERVAL, CONF_SLAVE, CONF_STRUCTURE, CONF_TIMEOUT, CONF_TYPE, ) from .const import ( CONF_DATA_TYPE, CONF_INPUT_TYPE, CONF_SLAVE_COUNT, CONF_SWAP, CONF_SWAP_BYTE, CONF_SWAP_NONE, CONF_WRITE_TYPE, DEFAULT_HUB, DEFAULT_SCAN_INTERVAL, PLATFORMS, SERIAL, DataType, ) _LOGGER = logging.getLogger(__name__) ENTRY = namedtuple("ENTRY", ["struct_id", "register_count"]) DEFAULT_STRUCT_FORMAT = { DataType.INT8: ENTRY("b", 1), DataType.INT16: ENTRY("h", 1), DataType.INT32: ENTRY("i", 2), DataType.INT64: ENTRY("q", 4), DataType.UINT8: ENTRY("c", 1), DataType.UINT16: ENTRY("H", 1), DataType.UINT32: ENTRY("I", 2), DataType.UINT64: ENTRY("Q", 4), DataType.FLOAT16: ENTRY("e", 1), DataType.FLOAT32: ENTRY("f", 2), DataType.FLOAT64: ENTRY("d", 4), DataType.STRING: ENTRY("s", 1), } def struct_validator(config: dict[str, Any]) -> dict[str, Any]: """Sensor schema validator.""" data_type = config[CONF_DATA_TYPE] count = config.get(CONF_COUNT, 1) name = config[CONF_NAME] structure = config.get(CONF_STRUCTURE) slave_count = config.get(CONF_SLAVE_COUNT, 0) + 1 swap_type = config.get(CONF_SWAP) if config[CONF_DATA_TYPE] != DataType.CUSTOM: if structure: error = f"{name} structure: cannot be mixed with {data_type}" raise vol.Invalid(error) if data_type not in DEFAULT_STRUCT_FORMAT: error = f"Error in sensor {name}. data_type `{data_type}` not supported" raise vol.Invalid(error) structure = f">{DEFAULT_STRUCT_FORMAT[data_type].struct_id}" if CONF_COUNT not in config: config[CONF_COUNT] = DEFAULT_STRUCT_FORMAT[data_type].register_count if slave_count > 1: structure = f">{slave_count}{DEFAULT_STRUCT_FORMAT[data_type].struct_id}" else: structure = f">{DEFAULT_STRUCT_FORMAT[data_type].struct_id}" else: if slave_count > 1: error = f"{name} structure: cannot be mixed with {CONF_SLAVE_COUNT}" raise vol.Invalid(error) if not structure: error = ( f"Error in sensor {name}. The `{CONF_STRUCTURE}` field can not be empty" ) raise vol.Invalid(error) try: size = struct.calcsize(structure) except struct.error as err: raise vol.Invalid(f"Error in {name} structure: {str(err)}") from err count = config.get(CONF_COUNT, 1) bytecount = count * 2 if bytecount != size: raise vol.Invalid( f"Structure request {size} bytes, " f"but {count} registers have a size of {bytecount} bytes" ) if swap_type != CONF_SWAP_NONE: if swap_type == CONF_SWAP_BYTE: regs_needed = 1 else: # CONF_SWAP_WORD_BYTE, CONF_SWAP_WORD regs_needed = 2 if count < regs_needed or (count % regs_needed) != 0: raise vol.Invalid( f"Error in sensor {name} swap({swap_type}) " f"not possible due to the registers " f"count: {count}, needed: {regs_needed}" ) return { **config, CONF_STRUCTURE: structure, CONF_SWAP: swap_type, } def number_validator(value: Any) -> int | float: """Coerce a value to number without losing precision.""" if isinstance(value, int): return value if isinstance(value, float): return value try: return int(value) except (TypeError, ValueError): pass try: return float(value) except (TypeError, ValueError) as err: raise vol.Invalid(f"invalid number {value}") from err def scan_interval_validator(config: dict) -> dict: """Control scan_interval.""" for hub in config: minimum_scan_interval = DEFAULT_SCAN_INTERVAL for component, conf_key in PLATFORMS: if conf_key not in hub: continue for entry in hub[conf_key]: scan_interval = entry.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL) if scan_interval == 0: continue if scan_interval < 5: _LOGGER.warning( "%s %s scan_interval(%d) is lower than 5 seconds, " "which may cause Home Assistant stability issues", component, entry.get(CONF_NAME), scan_interval, ) entry[CONF_SCAN_INTERVAL] = scan_interval minimum_scan_interval = min(scan_interval, minimum_scan_interval) if ( CONF_TIMEOUT in hub and hub[CONF_TIMEOUT] > minimum_scan_interval - 1 and minimum_scan_interval > 1 ): _LOGGER.warning( "Modbus %s timeout(%d) is adjusted(%d) due to scan_interval", hub.get(CONF_NAME, ""), hub[CONF_TIMEOUT], minimum_scan_interval - 1, ) hub[CONF_TIMEOUT] = minimum_scan_interval - 1 return config def duplicate_entity_validator(config: dict) -> dict: """Control scan_interval.""" for hub_index, hub in enumerate(config): for component, conf_key in PLATFORMS: if conf_key not in hub: continue names: set[str] = set() errors: list[int] = [] addresses: set[str] = set() for index, entry in enumerate(hub[conf_key]): name = entry[CONF_NAME] addr = str(entry[CONF_ADDRESS]) if CONF_INPUT_TYPE in entry: addr += "_" + str(entry[CONF_INPUT_TYPE]) elif CONF_WRITE_TYPE in entry: addr += "_" + str(entry[CONF_WRITE_TYPE]) if CONF_COMMAND_ON in entry: addr += "_" + str(entry[CONF_COMMAND_ON]) if CONF_COMMAND_OFF in entry: addr += "_" + str(entry[CONF_COMMAND_OFF]) addr += "_" + str(entry.get(CONF_SLAVE, 0)) if addr in addresses: err = f"Modbus {component}/{name} address {addr} is duplicate, second entry not loaded!" _LOGGER.warning(err) errors.append(index) elif name in names: err = f"Modbus {component}/{name}  is duplicate, second entry not loaded!" _LOGGER.warning(err) errors.append(index) else: names.add(name) addresses.add(addr) for i in reversed(errors): del config[hub_index][conf_key][i] return config def duplicate_modbus_validator(config: list) -> list: """Control modbus connection for duplicates.""" hosts: set[str] = set() names: set[str] = set() errors = [] for index, hub in enumerate(config): name = hub.get(CONF_NAME, DEFAULT_HUB) if hub[CONF_TYPE] == SERIAL: host = hub[CONF_PORT] else: host = f"{hub[CONF_HOST]}_{hub[CONF_PORT]}" if host in hosts: err = f"Modbus {name}  contains duplicate host/port {host}, not loaded!" _LOGGER.warning(err) errors.append(index) elif name in names: err = f"Modbus {name}  is duplicate, second entry not loaded!" _LOGGER.warning(err) errors.append(index) else: hosts.add(host) names.add(name) for i in reversed(errors): del config[i] return config