diff --git a/homeassistant/components/alexa/entities.py b/homeassistant/components/alexa/entities.py index d84848e9aba..e52e1b4f87e 100644 --- a/homeassistant/components/alexa/entities.py +++ b/homeassistant/components/alexa/entities.py @@ -391,6 +391,10 @@ class MediaPlayerCapabilities(AlexaEntity): def default_display_categories(self): """Return the display categories for this entity.""" + device_class = self.entity.attributes.get(ATTR_DEVICE_CLASS) + if device_class == media_player.DEVICE_CLASS_SPEAKER: + return [DisplayCategory.SPEAKER] + return [DisplayCategory.TV] def interfaces(self): diff --git a/tests/components/alexa/test_smart_home.py b/tests/components/alexa/test_smart_home.py index c50c0748147..1de7d404ef6 100644 --- a/tests/components/alexa/test_smart_home.py +++ b/tests/components/alexa/test_smart_home.py @@ -968,6 +968,25 @@ async def test_media_player_power(hass): ) +async def test_media_player_speaker(hass): + """Test media player discovery with device class speaker.""" + device = ( + "media_player.test", + "off", + { + "friendly_name": "Test media player", + "supported_features": 51765, + "volume_level": 0.75, + "device_class": "speaker", + }, + ) + appliance = await discovery_test(device, hass) + + assert appliance["endpointId"] == "media_player#test" + assert appliance["displayCategories"][0] == "SPEAKER" + assert appliance["friendlyName"] == "Test media player" + + async def test_alert(hass): """Test alert discovery.""" device = ("alert.test", "off", {"friendly_name": "Test alert"})