diff --git a/homeassistant/components/media_player/cast.py b/homeassistant/components/media_player/cast.py index 8468390c590..cf764fd723e 100644 --- a/homeassistant/components/media_player/cast.py +++ b/homeassistant/components/media_player/cast.py @@ -68,12 +68,19 @@ def setup_platform(hass, config, add_devices, discovery_info=None): casts = [] + # get_chromecasts() returns Chromecast objects + # with the correct friendly name for grouped devices + all_chromecasts = pychromecast.get_chromecasts() + for host in hosts: - try: - casts.append(CastDevice(*host)) - KNOWN_HOSTS.append(host) - except pychromecast.ChromecastConnectionError: - pass + found = [device for device in all_chromecasts + if (device.host, device.port) == host] + if found: + try: + casts.append(CastDevice(found[0])) + KNOWN_HOSTS.append(host) + except pychromecast.ChromecastConnectionError: + pass add_devices(casts) @@ -83,10 +90,9 @@ class CastDevice(MediaPlayerDevice): # pylint: disable=abstract-method # pylint: disable=too-many-public-methods - def __init__(self, host, port): + def __init__(self, chromecast): """Initialize the Cast device.""" - import pychromecast - self.cast = pychromecast.Chromecast(host, port) + self.cast = chromecast self.cast.socket_client.receiver_controller.register_status_listener( self) diff --git a/tests/components/media_player/test_cast.py b/tests/components/media_player/test_cast.py index 9930ae678f3..b4d4b15351c 100644 --- a/tests/components/media_player/test_cast.py +++ b/tests/components/media_player/test_cast.py @@ -6,12 +6,25 @@ from unittest.mock import patch from homeassistant.components.media_player import cast +class FakeChromeCast(object): + def __init__(self, host, port): + self.host = host + self.port = port + + class TestCastMediaPlayer(unittest.TestCase): """Test the media_player module.""" @patch('homeassistant.components.media_player.cast.CastDevice') - def test_filter_duplicates(self, mock_device): + @patch('pychromecast.get_chromecasts') + def test_filter_duplicates(self, mock_get_chromecasts, mock_device): """Test filtering of duplicates.""" + + mock_get_chromecasts.return_value = [ + FakeChromeCast('some_host', cast.DEFAULT_PORT) + ] + + # Test chromecasts as if they were hardcoded in configuration.yaml cast.setup_platform(None, { 'host': 'some_host' }, lambda _: _) @@ -21,6 +34,7 @@ class TestCastMediaPlayer(unittest.TestCase): mock_device.reset_mock() assert not mock_device.called + # Test chromecasts as if they were automatically discovered cast.setup_platform(None, {}, lambda _: _, ('some_host', cast.DEFAULT_PORT)) assert not mock_device.called