diff --git a/homeassistant/components/mqtt/number.py b/homeassistant/components/mqtt/number.py index f469130cb1c..bac70723eeb 100644 --- a/homeassistant/components/mqtt/number.py +++ b/homeassistant/components/mqtt/number.py @@ -5,7 +5,13 @@ import voluptuous as vol from homeassistant.components import number from homeassistant.components.number import NumberEntity -from homeassistant.const import CONF_DEVICE, CONF_NAME, CONF_UNIQUE_ID +from homeassistant.const import ( + CONF_DEVICE, + CONF_ICON, + CONF_NAME, + CONF_OPTIMISTIC, + CONF_UNIQUE_ID, +) from homeassistant.core import callback from homeassistant.helpers import config_validation as cv from homeassistant.helpers.dispatcher import ( @@ -13,11 +19,14 @@ from homeassistant.helpers.dispatcher import ( async_dispatcher_send, ) from homeassistant.helpers.reload import async_setup_reload_service +from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType, HomeAssistantType from . import ( ATTR_DISCOVERY_HASH, + CONF_COMMAND_TOPIC, CONF_QOS, + CONF_STATE_TOPIC, DOMAIN, PLATFORMS, MqttAttributes, @@ -32,15 +41,16 @@ from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_NEW, clear_discovery_ _LOGGER = logging.getLogger(__name__) -CONF_TOPIC = "topic" DEFAULT_NAME = "MQTT Number" +DEFAULT_OPTIMISTIC = False PLATFORM_SCHEMA = ( - mqtt.MQTT_BASE_PLATFORM_SCHEMA.extend( + mqtt.MQTT_RW_PLATFORM_SCHEMA.extend( { vol.Optional(CONF_DEVICE): mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA, + vol.Optional(CONF_ICON): cv.icon, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, - vol.Required(CONF_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_OPTIMISTIC, default=DEFAULT_OPTIMISTIC): cv.boolean, vol.Optional(CONF_UNIQUE_ID): cv.string, } ) @@ -94,16 +104,18 @@ class MqttNumber( MqttDiscoveryUpdate, MqttEntityDeviceInfo, NumberEntity, + RestoreEntity, ): """representation of an MQTT number.""" def __init__(self, config, config_entry, discovery_data): """Initialize the MQTT Number.""" self._config = config - self._unique_id = config.get(CONF_UNIQUE_ID) self._sub_state = None self._current_number = None + self._optimistic = config.get(CONF_OPTIMISTIC) + self._unique_id = config.get(CONF_UNIQUE_ID) device_config = config.get(CONF_DEVICE) @@ -145,18 +157,27 @@ class MqttNumber( except ValueError: _LOGGER.warning("We received <%s> which is not a Number", msg.payload) - self._sub_state = await subscription.async_subscribe_topics( - self.hass, - self._sub_state, - { - "state_topic": { - "topic": self._config[CONF_TOPIC], - "msg_callback": message_received, - "qos": self._config[CONF_QOS], - "encoding": None, - } - }, - ) + if self._config.get(CONF_STATE_TOPIC) is None: + # Force into optimistic mode. + self._optimistic = True + else: + self._sub_state = await subscription.async_subscribe_topics( + self.hass, + self._sub_state, + { + "state_topic": { + "topic": self._config.get(CONF_STATE_TOPIC), + "msg_callback": message_received, + "qos": self._config[CONF_QOS], + "encoding": None, + } + }, + ) + + if self._optimistic: + last_state = await self.async_get_last_state() + if last_state: + self._current_number = last_state.state async def async_will_remove_from_hass(self): """Unsubscribe when removed.""" @@ -174,20 +195,23 @@ class MqttNumber( async def async_set_value(self, value: float) -> None: """Update the current value.""" + + current_number = value + if value.is_integer(): - self._current_number = int(value) - else: - self._current_number = value + current_number = int(value) + + if self._optimistic: + self._current_number = current_number + self.async_write_ha_state() mqtt.async_publish( self.hass, - self._config[CONF_TOPIC], - self._current_number, + self._config[CONF_COMMAND_TOPIC], + current_number, self._config[CONF_QOS], ) - self.async_write_ha_state() - @property def name(self): """Return the name of this number.""" @@ -202,3 +226,13 @@ class MqttNumber( def should_poll(self): """Return the polling state.""" return False + + @property + def assumed_state(self): + """Return true if we do optimistic updates.""" + return self._optimistic + + @property + def icon(self): + """Return the icon.""" + return self._config.get(CONF_ICON) diff --git a/tests/components/mqtt/test_number.py b/tests/components/mqtt/test_number.py index dc7c7ebfe42..ac5285e9855 100644 --- a/tests/components/mqtt/test_number.py +++ b/tests/components/mqtt/test_number.py @@ -10,7 +10,8 @@ from homeassistant.components.number import ( DOMAIN as NUMBER_DOMAIN, SERVICE_SET_VALUE, ) -from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID +import homeassistant.core as ha from homeassistant.setup import async_setup_component from .test_common import ( @@ -40,7 +41,7 @@ from .test_common import ( from tests.common import async_fire_mqtt_message DEFAULT_CONFIG = { - number.DOMAIN: {"platform": "mqtt", "name": "test", "topic": "test_topic"} + number.DOMAIN: {"platform": "mqtt", "name": "test", "command_topic": "test-topic"} } @@ -50,7 +51,14 @@ async def test_run_number_setup(hass, mqtt_mock): await async_setup_component( hass, "number", - {"number": {"platform": "mqtt", "topic": topic, "name": "Test Number"}}, + { + "number": { + "platform": "mqtt", + "state_topic": topic, + "command_topic": topic, + "name": "Test Number", + } + }, ) await hass.async_block_till_done() @@ -72,12 +80,29 @@ async def test_run_number_setup(hass, mqtt_mock): async def test_run_number_service_optimistic(hass, mqtt_mock): """Test that set_value service works in optimistic mode.""" topic = "test/number" - await async_setup_component( - hass, - "number", - {"number": {"platform": "mqtt", "topic": topic, "name": "Test Number"}}, - ) - await hass.async_block_till_done() + + fake_state = ha.State("switch.test", "3") + + with patch( + "homeassistant.helpers.restore_state.RestoreEntity.async_get_last_state", + return_value=fake_state, + ): + assert await async_setup_component( + hass, + number.DOMAIN, + { + "number": { + "platform": "mqtt", + "command_topic": topic, + "name": "Test Number", + } + }, + ) + await hass.async_block_till_done() + + state = hass.states.get("number.test_number") + assert state.state == "3" + assert state.attributes.get(ATTR_ASSUMED_STATE) # Integer await hass.services.async_call( @@ -119,6 +144,40 @@ async def test_run_number_service_optimistic(hass, mqtt_mock): assert state.state == "42.1" +async def test_run_number_service(hass, mqtt_mock): + """Test that set_value service works in non optimistic mode.""" + cmd_topic = "test/number/set" + state_topic = "test/number" + + assert await async_setup_component( + hass, + number.DOMAIN, + { + "number": { + "platform": "mqtt", + "command_topic": cmd_topic, + "state_topic": state_topic, + "name": "Test Number", + } + }, + ) + await hass.async_block_till_done() + + async_fire_mqtt_message(hass, state_topic, "32") + state = hass.states.get("number.test_number") + assert state.state == "32" + + await hass.services.async_call( + NUMBER_DOMAIN, + SERVICE_SET_VALUE, + {ATTR_ENTITY_ID: "number.test_number", ATTR_VALUE: 30}, + blocking=True, + ) + mqtt_mock.async_publish.assert_called_once_with(cmd_topic, "30", 0, False) + state = hass.states.get("number.test_number") + assert state.state == "32" + + async def test_availability_when_connection_lost(hass, mqtt_mock): """Test availability after MQTT disconnection.""" await help_test_availability_when_connection_lost( @@ -189,13 +248,15 @@ async def test_unique_id(hass, mqtt_mock): { "platform": "mqtt", "name": "Test 1", - "topic": "test-topic", + "state_topic": "test-topic", + "command_topic": "test-topic", "unique_id": "TOTALLY_UNIQUE", }, { "platform": "mqtt", "name": "Test 2", - "topic": "test-topic", + "state_topic": "test-topic", + "command_topic": "test-topic", "unique_id": "TOTALLY_UNIQUE", }, ] @@ -211,8 +272,12 @@ async def test_discovery_removal_number(hass, mqtt_mock, caplog): async def test_discovery_update_number(hass, mqtt_mock, caplog): """Test update of discovered number.""" - data1 = '{ "name": "Beer", "topic": "test_topic"}' - data2 = '{ "name": "Milk", "topic": "test_topic"}' + data1 = ( + '{ "name": "Beer", "state_topic": "test-topic", "command_topic": "test-topic"}' + ) + data2 = ( + '{ "name": "Milk", "state_topic": "test-topic", "command_topic": "test-topic"}' + ) await help_test_discovery_update( hass, mqtt_mock, caplog, number.DOMAIN, data1, data2 @@ -221,7 +286,9 @@ async def test_discovery_update_number(hass, mqtt_mock, caplog): async def test_discovery_update_unchanged_number(hass, mqtt_mock, caplog): """Test update of discovered number.""" - data1 = '{ "name": "Beer", "topic": "test_topic"}' + data1 = ( + '{ "name": "Beer", "state_topic": "test-topic", "command_topic": "test-topic"}' + ) with patch( "homeassistant.components.mqtt.number.MqttNumber.discovery_update" ) as discovery_update: @@ -234,7 +301,9 @@ async def test_discovery_update_unchanged_number(hass, mqtt_mock, caplog): async def test_discovery_broken(hass, mqtt_mock, caplog): """Test handling of bad discovery message.""" data1 = '{ "name": "Beer" }' - data2 = '{ "name": "Milk", "topic": "test_topic"}' + data2 = ( + '{ "name": "Milk", "state_topic": "test-topic", "command_topic": "test-topic"}' + ) await help_test_discovery_broken( hass, mqtt_mock, caplog, number.DOMAIN, data1, data2 @@ -272,7 +341,7 @@ async def test_entity_device_info_remove(hass, mqtt_mock): async def test_entity_id_update_subscriptions(hass, mqtt_mock): """Test MQTT subscriptions are managed when entity_id is updated.""" await help_test_entity_id_update_subscriptions( - hass, mqtt_mock, number.DOMAIN, DEFAULT_CONFIG, ["test_topic"] + hass, mqtt_mock, number.DOMAIN, DEFAULT_CONFIG ) @@ -286,5 +355,5 @@ async def test_entity_id_update_discovery_update(hass, mqtt_mock): async def test_entity_debug_info_message(hass, mqtt_mock): """Test MQTT debug info.""" await help_test_entity_debug_info_message( - hass, mqtt_mock, number.DOMAIN, DEFAULT_CONFIG, "test_topic", b"ON" + hass, mqtt_mock, number.DOMAIN, DEFAULT_CONFIG, payload=b"1" )