diff --git a/homeassistant/components/samsungtv/media_player.py b/homeassistant/components/samsungtv/media_player.py index e7153a7f5d4..aca54838a99 100644 --- a/homeassistant/components/samsungtv/media_player.py +++ b/homeassistant/components/samsungtv/media_player.py @@ -98,7 +98,27 @@ class SamsungTVDevice(MediaPlayerDevice): def update(self): """Update state of device.""" - self.send_key("KEY") + if self._power_off_in_progress(): + self._state = STATE_OFF + else: + if self._remote is not None: + # Close the current remote connection + self._remote.close() + self._remote = None + + try: + self.get_remote() + if self._remote: + self._state = STATE_ON + except ( + samsung_exceptions.UnhandledResponse, + samsung_exceptions.AccessDenied, + ): + # We got a response so it's working. + self._state = STATE_ON + except (OSError, WebSocketException): + # Different reasons, e.g. hostname not resolveable + self._state = STATE_OFF def get_remote(self): """Create or return a remote control instance.""" @@ -128,19 +148,12 @@ class SamsungTVDevice(MediaPlayerDevice): # BrokenPipe can occur when the commands is sent to fast # WebSocketException can occur when timed out self._remote = None - self._state = STATE_ON except (samsung_exceptions.UnhandledResponse, samsung_exceptions.AccessDenied): # We got a response so it's on. - self._state = STATE_ON - self._remote = None LOGGER.debug("Failed sending command %s", key, exc_info=True) - return except OSError: # Different reasons, e.g. hostname not resolveable - self._state = STATE_OFF - self._remote = None - if self._power_off_in_progress(): - self._state = STATE_OFF + pass def _power_off_in_progress(self): return ( diff --git a/tests/components/samsungtv/test_media_player.py b/tests/components/samsungtv/test_media_player.py index 3afedda746e..2b9f379515d 100644 --- a/tests/components/samsungtv/test_media_player.py +++ b/tests/components/samsungtv/test_media_player.py @@ -135,16 +135,36 @@ async def test_update_on(hass, remote, mock_now): async def test_update_off(hass, remote, mock_now): """Testing update tv off.""" - await setup_samsungtv(hass, MOCK_CONFIG) - remote.control = mock.Mock(side_effect=OSError("Boom")) + with patch( + "homeassistant.components.samsungtv.media_player.SamsungRemote", + side_effect=[OSError("Boom"), mock.DEFAULT], + ), patch("homeassistant.components.samsungtv.config_flow.socket"): + await setup_samsungtv(hass, MOCK_CONFIG) - next_update = mock_now + timedelta(minutes=5) - with patch("homeassistant.util.dt.utcnow", return_value=next_update): - async_fire_time_changed(hass, next_update) - await hass.async_block_till_done() + next_update = mock_now + timedelta(minutes=5) + with patch("homeassistant.util.dt.utcnow", return_value=next_update): + async_fire_time_changed(hass, next_update) + await hass.async_block_till_done() - state = hass.states.get(ENTITY_ID) - assert state.state == STATE_OFF + state = hass.states.get(ENTITY_ID) + assert state.state == STATE_OFF + + +async def test_update_unhandled_response(hass, remote, mock_now): + """Testing update tv unhandled response exception.""" + with patch( + "homeassistant.components.samsungtv.media_player.SamsungRemote", + side_effect=[exceptions.UnhandledResponse("Boom"), mock.DEFAULT], + ), patch("homeassistant.components.samsungtv.config_flow.socket"): + await setup_samsungtv(hass, MOCK_CONFIG) + + next_update = mock_now + timedelta(minutes=5) + with patch("homeassistant.util.dt.utcnow", return_value=next_update): + async_fire_time_changed(hass, next_update) + await hass.async_block_till_done() + + state = hass.states.get(ENTITY_ID) + assert state.state == STATE_ON async def test_send_key(hass, remote): @@ -155,8 +175,10 @@ async def test_send_key(hass, remote): ) state = hass.states.get(ENTITY_ID) # key and update called - assert remote.control.call_count == 2 - assert remote.control.call_args_list == [call("KEY_VOLUP"), call("KEY")] + assert remote.control.call_count == 1 + assert remote.control.call_args_list == [call("KEY_VOLUP")] + assert remote.close.call_count == 1 + assert remote.close.call_args_list == [call()] assert state.state == STATE_ON @@ -182,12 +204,13 @@ async def test_send_key_connection_closed_retry_succeed(hass, remote): ) state = hass.states.get(ENTITY_ID) # key because of retry two times and update called - assert remote.control.call_count == 3 + assert remote.control.call_count == 2 assert remote.control.call_args_list == [ call("KEY_VOLUP"), call("KEY_VOLUP"), - call("KEY"), ] + assert remote.close.call_count == 1 + assert remote.close.call_args_list == [call()] assert state.state == STATE_ON @@ -221,7 +244,7 @@ async def test_send_key_os_error(hass, remote): DOMAIN, SERVICE_VOLUME_UP, {ATTR_ENTITY_ID: ENTITY_ID}, True ) state = hass.states.get(ENTITY_ID) - assert state.state == STATE_OFF + assert state.state == STATE_ON async def test_name(hass, remote): @@ -336,8 +359,10 @@ async def test_volume_up(hass, remote): DOMAIN, SERVICE_VOLUME_UP, {ATTR_ENTITY_ID: ENTITY_ID}, True ) # key and update called - assert remote.control.call_count == 2 - assert remote.control.call_args_list == [call("KEY_VOLUP"), call("KEY")] + assert remote.control.call_count == 1 + assert remote.control.call_args_list == [call("KEY_VOLUP")] + assert remote.close.call_count == 1 + assert remote.close.call_args_list == [call()] async def test_volume_down(hass, remote): @@ -347,8 +372,10 @@ async def test_volume_down(hass, remote): DOMAIN, SERVICE_VOLUME_DOWN, {ATTR_ENTITY_ID: ENTITY_ID}, True ) # key and update called - assert remote.control.call_count == 2 - assert remote.control.call_args_list == [call("KEY_VOLDOWN"), call("KEY")] + assert remote.control.call_count == 1 + assert remote.control.call_args_list == [call("KEY_VOLDOWN")] + assert remote.close.call_count == 1 + assert remote.close.call_args_list == [call()] async def test_mute_volume(hass, remote): @@ -361,8 +388,10 @@ async def test_mute_volume(hass, remote): True, ) # key and update called - assert remote.control.call_count == 2 - assert remote.control.call_args_list == [call("KEY_MUTE"), call("KEY")] + assert remote.control.call_count == 1 + assert remote.control.call_args_list == [call("KEY_MUTE")] + assert remote.close.call_count == 1 + assert remote.close.call_args_list == [call()] async def test_media_play(hass, remote): @@ -372,8 +401,10 @@ async def test_media_play(hass, remote): DOMAIN, SERVICE_MEDIA_PLAY, {ATTR_ENTITY_ID: ENTITY_ID}, True ) # key and update called - assert remote.control.call_count == 2 - assert remote.control.call_args_list == [call("KEY_PLAY"), call("KEY")] + assert remote.control.call_count == 1 + assert remote.control.call_args_list == [call("KEY_PLAY")] + assert remote.close.call_count == 1 + assert remote.close.call_args_list == [call()] async def test_media_pause(hass, remote): @@ -383,8 +414,10 @@ async def test_media_pause(hass, remote): DOMAIN, SERVICE_MEDIA_PAUSE, {ATTR_ENTITY_ID: ENTITY_ID}, True ) # key and update called - assert remote.control.call_count == 2 - assert remote.control.call_args_list == [call("KEY_PAUSE"), call("KEY")] + assert remote.control.call_count == 1 + assert remote.control.call_args_list == [call("KEY_PAUSE")] + assert remote.close.call_count == 1 + assert remote.close.call_args_list == [call()] async def test_media_next_track(hass, remote): @@ -394,8 +427,10 @@ async def test_media_next_track(hass, remote): DOMAIN, SERVICE_MEDIA_NEXT_TRACK, {ATTR_ENTITY_ID: ENTITY_ID}, True ) # key and update called - assert remote.control.call_count == 2 - assert remote.control.call_args_list == [call("KEY_CHUP"), call("KEY")] + assert remote.control.call_count == 1 + assert remote.control.call_args_list == [call("KEY_CHUP")] + assert remote.close.call_count == 1 + assert remote.close.call_args_list == [call()] async def test_media_previous_track(hass, remote): @@ -405,8 +440,10 @@ async def test_media_previous_track(hass, remote): DOMAIN, SERVICE_MEDIA_PREVIOUS_TRACK, {ATTR_ENTITY_ID: ENTITY_ID}, True ) # key and update called - assert remote.control.call_count == 2 - assert remote.control.call_args_list == [call("KEY_CHDOWN"), call("KEY")] + assert remote.control.call_count == 1 + assert remote.control.call_args_list == [call("KEY_CHDOWN")] + assert remote.close.call_count == 1 + assert remote.close.call_args_list == [call()] async def test_turn_on_with_turnon(hass, remote, delay): @@ -450,71 +487,84 @@ async def test_play_media(hass, remote): True, ) # keys and update called - assert remote.control.call_count == 5 + assert remote.control.call_count == 4 assert remote.control.call_args_list == [ call("KEY_5"), call("KEY_7"), call("KEY_6"), call("KEY_ENTER"), - call("KEY"), ] + assert remote.close.call_count == 1 + assert remote.close.call_args_list == [call()] assert len(sleeps) == 3 async def test_play_media_invalid_type(hass, remote): """Test for play_media with invalid media type.""" - url = "https://example.com" - await setup_samsungtv(hass, MOCK_CONFIG) - assert await hass.services.async_call( - DOMAIN, - SERVICE_PLAY_MEDIA, - { - ATTR_ENTITY_ID: ENTITY_ID, - ATTR_MEDIA_CONTENT_TYPE: MEDIA_TYPE_URL, - ATTR_MEDIA_CONTENT_ID: url, - }, - True, - ) - # only update called - assert remote.control.call_count == 1 - assert remote.control.call_args_list == [call("KEY")] + with patch( + "homeassistant.components.samsungtv.media_player.SamsungRemote" + ) as remote, patch("homeassistant.components.samsungtv.config_flow.socket"): + url = "https://example.com" + await setup_samsungtv(hass, MOCK_CONFIG) + assert await hass.services.async_call( + DOMAIN, + SERVICE_PLAY_MEDIA, + { + ATTR_ENTITY_ID: ENTITY_ID, + ATTR_MEDIA_CONTENT_TYPE: MEDIA_TYPE_URL, + ATTR_MEDIA_CONTENT_ID: url, + }, + True, + ) + # only update called + assert remote.control.call_count == 0 + assert remote.close.call_count == 0 + assert remote.call_count == 1 async def test_play_media_channel_as_string(hass, remote): """Test for play_media with invalid channel as string.""" - url = "https://example.com" - await setup_samsungtv(hass, MOCK_CONFIG) - assert await hass.services.async_call( - DOMAIN, - SERVICE_PLAY_MEDIA, - { - ATTR_ENTITY_ID: ENTITY_ID, - ATTR_MEDIA_CONTENT_TYPE: MEDIA_TYPE_CHANNEL, - ATTR_MEDIA_CONTENT_ID: url, - }, - True, - ) - # only update called - assert remote.control.call_count == 1 - assert remote.control.call_args_list == [call("KEY")] + with patch( + "homeassistant.components.samsungtv.media_player.SamsungRemote" + ) as remote, patch("homeassistant.components.samsungtv.config_flow.socket"): + url = "https://example.com" + await setup_samsungtv(hass, MOCK_CONFIG) + assert await hass.services.async_call( + DOMAIN, + SERVICE_PLAY_MEDIA, + { + ATTR_ENTITY_ID: ENTITY_ID, + ATTR_MEDIA_CONTENT_TYPE: MEDIA_TYPE_CHANNEL, + ATTR_MEDIA_CONTENT_ID: url, + }, + True, + ) + # only update called + assert remote.control.call_count == 0 + assert remote.close.call_count == 0 + assert remote.call_count == 1 async def test_play_media_channel_as_non_positive(hass, remote): """Test for play_media with invalid channel as non positive integer.""" - await setup_samsungtv(hass, MOCK_CONFIG) - assert await hass.services.async_call( - DOMAIN, - SERVICE_PLAY_MEDIA, - { - ATTR_ENTITY_ID: ENTITY_ID, - ATTR_MEDIA_CONTENT_TYPE: MEDIA_TYPE_CHANNEL, - ATTR_MEDIA_CONTENT_ID: "-4", - }, - True, - ) - # only update called - assert remote.control.call_count == 1 - assert remote.control.call_args_list == [call("KEY")] + with patch( + "homeassistant.components.samsungtv.media_player.SamsungRemote" + ) as remote, patch("homeassistant.components.samsungtv.config_flow.socket"): + await setup_samsungtv(hass, MOCK_CONFIG) + assert await hass.services.async_call( + DOMAIN, + SERVICE_PLAY_MEDIA, + { + ATTR_ENTITY_ID: ENTITY_ID, + ATTR_MEDIA_CONTENT_TYPE: MEDIA_TYPE_CHANNEL, + ATTR_MEDIA_CONTENT_ID: "-4", + }, + True, + ) + # only update called + assert remote.control.call_count == 0 + assert remote.close.call_count == 0 + assert remote.call_count == 1 async def test_select_source(hass, remote): @@ -527,19 +577,25 @@ async def test_select_source(hass, remote): True, ) # key and update called - assert remote.control.call_count == 2 - assert remote.control.call_args_list == [call("KEY_HDMI"), call("KEY")] + assert remote.control.call_count == 1 + assert remote.control.call_args_list == [call("KEY_HDMI")] + assert remote.close.call_count == 1 + assert remote.close.call_args_list == [call()] async def test_select_source_invalid_source(hass, remote): """Test for select_source with invalid source.""" - await setup_samsungtv(hass, MOCK_CONFIG) - assert await hass.services.async_call( - DOMAIN, - SERVICE_SELECT_SOURCE, - {ATTR_ENTITY_ID: ENTITY_ID, ATTR_INPUT_SOURCE: "INVALID"}, - True, - ) - # only update called - assert remote.control.call_count == 1 - assert remote.control.call_args_list == [call("KEY")] + with patch( + "homeassistant.components.samsungtv.media_player.SamsungRemote" + ) as remote, patch("homeassistant.components.samsungtv.config_flow.socket"): + await setup_samsungtv(hass, MOCK_CONFIG) + assert await hass.services.async_call( + DOMAIN, + SERVICE_SELECT_SOURCE, + {ATTR_ENTITY_ID: ENTITY_ID, ATTR_INPUT_SOURCE: "INVALID"}, + True, + ) + # only update called + assert remote.control.call_count == 0 + assert remote.close.call_count == 0 + assert remote.call_count == 1