From baa3b15dbc7cef6f6e3b765b284fcf58c3eaacdd Mon Sep 17 00:00:00 2001 From: Claudio Ruggeri - CR-Tech <41435902+crug80@users.noreply.github.com> Date: Sat, 22 Feb 2025 04:16:15 +0100 Subject: [PATCH] Fix write_registers calling after the upgrade of pymodbus to 3.8.x (#139017) --- homeassistant/components/modbus/modbus.py | 5 +++++ tests/components/modbus/conftest.py | 1 + tests/components/modbus/test_switch.py | 23 +++++++++++++++++++++++ 3 files changed, 29 insertions(+) diff --git a/homeassistant/components/modbus/modbus.py b/homeassistant/components/modbus/modbus.py index 81cfc3127d1..006ef504590 100644 --- a/homeassistant/components/modbus/modbus.py +++ b/homeassistant/components/modbus/modbus.py @@ -384,6 +384,11 @@ class ModbusHub: {ATTR_SLAVE: slave} if slave is not None else {ATTR_SLAVE: 1} ) entry = self._pb_request[use_call] + + if use_call in {"write_registers", "write_coils"}: + if not isinstance(value, list): + value = [value] + kwargs[entry.value_attr_name] = value try: result: ModbusPDU = await entry.func(address, **kwargs) diff --git a/tests/components/modbus/conftest.py b/tests/components/modbus/conftest.py index 0a2cbf44b9e..a35cc95605d 100644 --- a/tests/components/modbus/conftest.py +++ b/tests/components/modbus/conftest.py @@ -42,6 +42,7 @@ class ReadResult: self.registers = register_words self.bits = register_words self.value = register_words + self.count = len(register_words) if register_words is not None else 0 def isError(self): """Set error state.""" diff --git a/tests/components/modbus/test_switch.py b/tests/components/modbus/test_switch.py index 4b2c123ba75..fc994c70d49 100644 --- a/tests/components/modbus/test_switch.py +++ b/tests/components/modbus/test_switch.py @@ -12,6 +12,7 @@ from homeassistant.components.modbus.const import ( CALL_TYPE_DISCRETE, CALL_TYPE_REGISTER_HOLDING, CALL_TYPE_REGISTER_INPUT, + CALL_TYPE_X_REGISTER_HOLDINGS, CONF_DEVICE_ADDRESS, CONF_INPUT_TYPE, CONF_STATE_OFF, @@ -50,6 +51,7 @@ from tests.common import async_fire_time_changed ENTITY_ID = f"{SWITCH_DOMAIN}.{TEST_ENTITY_NAME}".replace(" ", "_") ENTITY_ID2 = f"{ENTITY_ID}_2" ENTITY_ID3 = f"{ENTITY_ID}_3" +ENTITY_ID4 = f"{ENTITY_ID}_4" @pytest.mark.parametrize( @@ -330,6 +332,13 @@ async def test_restore_state_switch( CONF_SCAN_INTERVAL: 0, CONF_VERIFY: {CONF_STATE_ON: [1, 3]}, }, + { + CONF_NAME: f"{TEST_ENTITY_NAME} 4", + CONF_ADDRESS: 19, + CONF_WRITE_TYPE: CALL_TYPE_X_REGISTER_HOLDINGS, + CONF_SCAN_INTERVAL: 0, + CONF_VERIFY: {CONF_STATE_ON: [1, 3]}, + }, ], }, ], @@ -381,6 +390,20 @@ async def test_switch_service_turn( await hass.async_block_till_done() assert hass.states.get(ENTITY_ID3).state == STATE_OFF + mock_modbus.read_holding_registers.return_value = ReadResult([0x03]) + assert hass.states.get(ENTITY_ID4).state == STATE_OFF + await hass.services.async_call( + SWITCH_DOMAIN, SERVICE_TURN_ON, service_data={ATTR_ENTITY_ID: ENTITY_ID4} + ) + await hass.async_block_till_done() + assert hass.states.get(ENTITY_ID4).state == STATE_ON + mock_modbus.read_holding_registers.return_value = ReadResult([0x00]) + await hass.services.async_call( + SWITCH_DOMAIN, SERVICE_TURN_OFF, service_data={ATTR_ENTITY_ID: ENTITY_ID4} + ) + await hass.async_block_till_done() + assert hass.states.get(ENTITY_ID4).state == STATE_OFF + mock_modbus.write_register.side_effect = ModbusException("fail write_") await hass.services.async_call( SWITCH_DOMAIN, SERVICE_TURN_ON, service_data={ATTR_ENTITY_ID: ENTITY_ID2}