490 lines
16 KiB
Python
490 lines
16 KiB
Python
"""The tests for the Entity component helper."""
|
|
# pylint: disable=protected-access
|
|
from collections import OrderedDict
|
|
from datetime import timedelta
|
|
import logging
|
|
|
|
import pytest
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.const import ENTITY_MATCH_ALL, ENTITY_MATCH_NONE
|
|
import homeassistant.core as ha
|
|
from homeassistant.exceptions import PlatformNotReady
|
|
from homeassistant.helpers import discovery
|
|
from homeassistant.helpers.entity_component import EntityComponent
|
|
from homeassistant.setup import async_setup_component
|
|
import homeassistant.util.dt as dt_util
|
|
|
|
from tests.async_mock import AsyncMock, Mock, patch
|
|
from tests.common import (
|
|
MockConfigEntry,
|
|
MockEntity,
|
|
MockModule,
|
|
MockPlatform,
|
|
async_fire_time_changed,
|
|
mock_entity_platform,
|
|
mock_integration,
|
|
)
|
|
|
|
# Entity domain that every EntityComponent in this module is created for.
DOMAIN = "test_domain"

# Logger passed to each EntityComponent constructed by the tests below.
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
async def test_setup_loads_platforms(hass):
    """Setting up the component loads the configured platform and its deps.

    The ``mod2`` platform depends on the ``test_component`` integration.
    Neither setup hook may have run before ``component.setup`` is called,
    and both must have run once the event loop has drained.
    """
    component_setup_mock = Mock(return_value=True)
    platform_setup_mock = Mock(return_value=None)

    mock_integration(hass, MockModule("test_component", setup=component_setup_mock))
    # "mod2" declares test_component as a dependency, so loading the platform
    # is expected to pull that integration in too.
    mock_integration(hass, MockModule("mod2", dependencies=["test_component"]))
    mock_entity_platform(hass, "test_domain.mod2", MockPlatform(platform_setup_mock))

    component = EntityComponent(_LOGGER, DOMAIN, hass)

    for hook in (component_setup_mock, platform_setup_mock):
        assert not hook.called

    component.setup({DOMAIN: {"platform": "mod2"}})
    await hass.async_block_till_done()

    for hook in (component_setup_mock, platform_setup_mock):
        assert hook.called
|
|
|
|
|
|
async def test_setup_recovers_when_setup_raises(hass):
    """A crashing or missing platform does not stop the healthy one loading.

    ``mod1`` raises from its setup function, and ``non_exist`` is never
    registered in this test. The ``mod2`` platform must still be set up.
    """
    crashing_setup = Mock(side_effect=Exception("Broken"))
    healthy_setup = Mock(return_value=None)

    for platform_key, setup_mock in (
        ("test_domain.mod1", crashing_setup),
        ("test_domain.mod2", healthy_setup),
    ):
        mock_entity_platform(hass, platform_key, MockPlatform(setup_mock))

    component = EntityComponent(_LOGGER, DOMAIN, hass)

    assert not crashing_setup.called
    assert not healthy_setup.called

    # Build the config entry by entry so the declared order is explicit:
    # failing platform first, missing platform second, healthy one last.
    config = OrderedDict()
    config[DOMAIN] = {"platform": "mod1"}
    config[f"{DOMAIN} 2"] = {"platform": "non_exist"}
    config[f"{DOMAIN} 3"] = {"platform": "mod2"}
    component.setup(config)

    await hass.async_block_till_done()
    assert crashing_setup.called
    assert healthy_setup.called
|
|
|
|
|
|
@patch(
    "homeassistant.helpers.entity_component.EntityComponent.async_setup_platform",
)
@patch("homeassistant.setup.async_setup_component", return_value=True)
async def test_setup_does_discovery(mock_setup_component, mock_setup, hass):
    """A platform announced via discovery is handed to async_setup_platform.

    ``mock_setup_component`` comes from the bottom ``@patch``, which stubs
    component setup to return True. ``mock_setup`` comes from the top one
    and records the platform setup call that is checked here.
    """
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    component.setup({})

    discovery.load_platform(
        hass,
        DOMAIN,
        "platform_test",
        {"msg": "discovery_info"},
        {DOMAIN: {}},
    )
    await hass.async_block_till_done()

    assert mock_setup.called
    # Positional args: platform type, platform config, discovery info.
    expected_args = ("platform_test", {}, {"msg": "discovery_info"})
    assert mock_setup.call_args[0] == expected_args
|
|
|
|
|
|
@patch("homeassistant.helpers.entity_platform.async_track_time_interval")
async def test_set_scan_interval_via_config(mock_track, hass):
    """A ``scan_interval`` given in config is used for the polling timer.

    ``mock_track`` replaces async_track_time_interval. Its third positional
    argument must equal the configured 30 second interval.
    """

    def platform_setup(hass, config, add_entities, discovery_info=None):
        """Register a single entity that asks to be polled."""
        add_entities([MockEntity(should_poll=True)])

    mock_entity_platform(hass, "test_domain.platform", MockPlatform(platform_setup))

    configured_interval = timedelta(seconds=30)
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    component.setup(
        {DOMAIN: {"platform": "platform", "scan_interval": configured_interval}}
    )
    await hass.async_block_till_done()

    assert mock_track.called
    tracked_interval = mock_track.call_args[0][2]
    assert tracked_interval == configured_interval
|
|
|
|
|
|
async def test_set_entity_namespace_via_config(hass):
    """``entity_namespace`` from config is prefixed onto generated entity IDs.

    The entity named "beer" becomes ``yummy_beer``. The entity without a name
    ends up as ``yummy_unnamed_device``.
    """

    def platform_setup(hass, config, add_entities, discovery_info=None):
        """Add one named and one unnamed entity."""
        named = MockEntity(name="beer")
        unnamed = MockEntity(name=None)
        add_entities([named, unnamed])

    mock_entity_platform(hass, "test_domain.platform", MockPlatform(platform_setup))

    component = EntityComponent(_LOGGER, DOMAIN, hass)
    component.setup({DOMAIN: {"platform": "platform", "entity_namespace": "yummy"}})
    await hass.async_block_till_done()

    expected_ids = [
        "test_domain.yummy_beer",
        "test_domain.yummy_unnamed_device",
    ]
    assert sorted(hass.states.async_entity_ids()) == expected_ids
|
|
|
|
|
|
async def test_extract_from_service_available_device(hass):
    """Only available entities are extracted from a service call.

    This is checked for an ``ENTITY_MATCH_ALL`` target and for an explicit
    list of entity IDs that includes an unavailable entity.
    """
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    mock_entities = [
        MockEntity(name="test_1"),
        MockEntity(name="test_2", available=False),
        MockEntity(name="test_3"),
        MockEntity(name="test_4", available=False),
    ]
    await component.async_add_entities(mock_entities)

    async def extracted_ids(call):
        """Return the sorted entity IDs the component extracts for ``call``."""
        entities = await component.async_extract_from_service(call)
        return sorted(ent.entity_id for ent in entities)

    match_all_call = ha.ServiceCall(
        "test", "service", data={"entity_id": ENTITY_MATCH_ALL}
    )
    assert await extracted_ids(match_all_call) == [
        "test_domain.test_1",
        "test_domain.test_3",
    ]

    explicit_call = ha.ServiceCall(
        "test",
        "service",
        data={"entity_id": ["test_domain.test_3", "test_domain.test_4"]},
    )
    assert await extracted_ids(explicit_call) == ["test_domain.test_3"]
|
|
|
|
|
|
async def test_platform_not_ready(hass, legacy_patchable_time):
    """Platform setup is retried after PlatformNotReady until it succeeds.

    The mocked setup raises PlatformNotReady twice and then succeeds.
    ``utcnow`` is frozen right after the first attempt. The second attempt
    fires at +30s (not +29s) and the third at +60s (not +59s). The platform
    only appears in ``hass.config.components`` after the successful third
    attempt.
    """
    platform1_setup = Mock(side_effect=[PlatformNotReady, PlatformNotReady, None])
    mock_integration(hass, MockModule("mod1"))
    mock_entity_platform(hass, "test_domain.mod1", MockPlatform(platform1_setup))

    component = EntityComponent(_LOGGER, DOMAIN, hass)

    def attempts():
        """Return how many times the platform setup has been invoked."""
        return len(platform1_setup.mock_calls)

    await component.async_setup({DOMAIN: {"platform": "mod1"}})
    await hass.async_block_till_done()
    assert attempts() == 1
    assert "test_domain.mod1" not in hass.config.components

    frozen_now = dt_util.utcnow()

    async def fire_at(offset_seconds):
        """Fire a time-changed event ``offset_seconds`` past ``frozen_now``."""
        async_fire_time_changed(hass, frozen_now + timedelta(seconds=offset_seconds))
        await hass.async_block_till_done()

    with patch("homeassistant.util.dt.utcnow", return_value=frozen_now):
        # Each pair is (seconds after frozen_now, expected total attempts).
        for offset_seconds, expected in ((29, 1), (30, 2)):
            await fire_at(offset_seconds)
            assert attempts() == expected
        assert "test_domain.mod1" not in hass.config.components

        for offset_seconds, expected in ((59, 2), (60, 3)):
            await fire_at(offset_seconds)
            assert attempts() == expected
        assert "test_domain.mod1" in hass.config.components
|
|
|
|
|
|
async def test_extract_from_service_fails_if_no_entity_id(hass):
    """Test that nothing is extracted without an explicit entity target.

    Two entities are registered. Each of the following service calls must
    still extract an empty list:
    - a call with no target at all
    - ``entity_id`` set to ``ENTITY_MATCH_NONE``
    - ``area_id`` set to ``ENTITY_MATCH_NONE``
    """
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    await component.async_add_entities(
        [MockEntity(name="test_1"), MockEntity(name="test_2")]
    )

    untargeted_calls = (
        ha.ServiceCall("test", "service"),
        ha.ServiceCall("test", "service", {"entity_id": ENTITY_MATCH_NONE}),
        ha.ServiceCall("test", "service", {"area_id": ENTITY_MATCH_NONE}),
    )
    for call in untargeted_calls:
        # Include the call data so a failure shows which variant broke.
        assert await component.async_extract_from_service(call) == [], call.data
|
|
|
|
|
|
async def test_extract_from_service_filter_out_non_existing_entities(hass):
    """Unknown entity IDs in a service call are dropped from the result.

    The call targets one registered entity and one ID that was never added.
    Only the registered entity is extracted.
    """
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    first, second = MockEntity(name="test_1"), MockEntity(name="test_2")
    await component.async_add_entities([first, second])

    targets = ["test_domain.test_2", "test_domain.non_exist"]
    call = ha.ServiceCall("test", "service", {"entity_id": targets})

    extracted = await component.async_extract_from_service(call)
    assert [ent.entity_id for ent in extracted] == ["test_domain.test_2"]
|
|
|
|
|
|
async def test_extract_from_service_no_group_expand(hass):
    """With ``expand_group=False`` the group entity itself is extracted."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    group_entity = MockEntity(entity_id="group.test_group")
    await component.async_add_entities([group_entity])

    call = ha.ServiceCall("test", "service", {"entity_id": ["group.test_group"]})
    extracted = await component.async_extract_from_service(call, expand_group=False)

    assert [ent.entity_id for ent in extracted] == ["group.test_group"]
|
|
|
|
|
|
async def test_setup_dependencies_platform(hass):
    """Dependencies of a platform are set up along with the platform.

    This checks that dependencies are processed even when a component with
    the same name as the platform has already been loaded. The
    ``test_component`` integration depends on ``test_component2``. After
    setting up the ``test_domain.test_component`` platform, all three must
    be listed in ``hass.config.components``.
    """
    mock_integration(
        hass, MockModule("test_component", dependencies=["test_component2"])
    )
    mock_integration(hass, MockModule("test_component2"))
    mock_entity_platform(hass, "test_domain.test_component", MockPlatform())

    component = EntityComponent(_LOGGER, DOMAIN, hass)
    await component.async_setup({DOMAIN: {"platform": "test_component"}})
    await hass.async_block_till_done()

    for expected in (
        "test_component",
        "test_component2",
        "test_domain.test_component",
    ):
        assert expected in hass.config.components
|
|
|
|
|
|
async def test_setup_entry(hass):
    """Setting up a config entry forwards it to the platform's async_setup_entry.

    The platform's ``scan_interval`` must also be carried over to the platform
    object the component stores for the entry.
    """
    mock_setup_entry = AsyncMock(return_value=True)
    platform = MockPlatform(
        async_setup_entry=mock_setup_entry, scan_interval=timedelta(seconds=5)
    )
    mock_entity_platform(hass, "test_domain.entry_domain", platform)

    component = EntityComponent(_LOGGER, DOMAIN, hass)
    entry = MockConfigEntry(domain="entry_domain")

    assert await component.async_setup_entry(entry)
    assert len(mock_setup_entry.mock_calls) == 1

    received_hass, received_entry, _ = mock_setup_entry.mock_calls[0][1]
    assert received_hass is hass
    assert received_entry is entry

    entry_platform = component._platforms[entry.entry_id]
    assert entry_platform.scan_interval == timedelta(seconds=5)
|
|
|
|
|
|
async def test_setup_entry_platform_not_exist(hass):
    """Setting up an entry whose platform was never registered returns False."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    entry = MockConfigEntry(domain="non_existing")

    setup_result = await component.async_setup_entry(entry)
    assert setup_result is False
|
|
|
|
|
|
async def test_setup_entry_fails_duplicate(hass):
    """A config entry that is already set up cannot be set up a second time."""
    setup_entry_mock = AsyncMock(return_value=True)
    mock_entity_platform(
        hass,
        "test_domain.entry_domain",
        MockPlatform(async_setup_entry=setup_entry_mock),
    )

    component = EntityComponent(_LOGGER, DOMAIN, hass)
    entry = MockConfigEntry(domain="entry_domain")

    first_result = await component.async_setup_entry(entry)
    assert first_result

    # Re-using the very same entry must be rejected with ValueError.
    with pytest.raises(ValueError):
        await component.async_setup_entry(entry)
|
|
|
|
|
|
async def test_unload_entry_resets_platform(hass):
    """Unloading a config entry removes the entities that entry added."""
    setup_entry_mock = AsyncMock(return_value=True)
    mock_entity_platform(
        hass,
        "test_domain.entry_domain",
        MockPlatform(async_setup_entry=setup_entry_mock),
    )

    component = EntityComponent(_LOGGER, DOMAIN, hass)
    entry = MockConfigEntry(domain="entry_domain")

    assert await component.async_setup_entry(entry)
    assert len(setup_entry_mock.mock_calls) == 1

    # The third positional argument given to async_setup_entry is the
    # add-entities callback; use it to register one entity for this entry.
    first_call_args = setup_entry_mock.mock_calls[0][1]
    add_entities = first_call_args[2]
    add_entities([MockEntity()])
    await hass.async_block_till_done()
    assert len(hass.states.async_entity_ids()) == 1

    assert await component.async_unload_entry(entry)
    assert len(hass.states.async_entity_ids()) == 0
|
|
|
|
|
|
async def test_unload_entry_fails_if_never_loaded(hass):
    """Test that unloading a never-set-up config entry raises ValueError."""
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    entry = MockConfigEntry(domain="entry_domain")

    # async_setup_entry was never called for this entry, so the component
    # has nothing to unload for it.
    with pytest.raises(ValueError):
        await component.async_unload_entry(entry)
|
|
|
|
|
|
async def test_update_entity(hass):
    """The ``async_update_entity`` helper updates the entity's state.

    It must call the entity's ``async_update_ha_state`` exactly once, with
    True as the first positional argument.
    """
    component = EntityComponent(_LOGGER, DOMAIN, hass)

    entity = MockEntity()
    entity.async_write_ha_state = Mock()
    entity.async_update_ha_state = AsyncMock(return_value=None)
    await component.async_add_entities([entity])

    # Adding the entity writes its state once.
    assert len(entity.async_write_ha_state.mock_calls) == 1

    await hass.helpers.entity_component.async_update_entity(entity.entity_id)

    update_calls = entity.async_update_ha_state.mock_calls
    assert len(update_calls) == 1
    # NOTE(review): presumably this argument is the force-refresh flag; confirm.
    assert update_calls[-1][1][0] is True
|
|
|
|
|
|
async def test_set_service_race(hass):
    """Test race condition on setting service.

    Two ``async_add_entities`` calls are scheduled back to back on a freshly
    created component. The event loop must not report any unhandled
    exception while they run.
    """
    loop_errors = []

    def async_loop_exception_handler(_, context) -> None:
        """Record every exception context reported by the core loop."""
        loop_errors.append(context)

    previous_handler = hass.loop.get_exception_handler()
    hass.loop.set_exception_handler(async_loop_exception_handler)
    try:
        await async_setup_component(hass, "group", {})
        component = EntityComponent(_LOGGER, DOMAIN, hass)

        for _ in range(2):
            hass.async_create_task(component.async_add_entities([MockEntity()]))

        await hass.async_block_till_done()
    finally:
        # Put back whatever handler was installed before this test body ran,
        # so code after the test is not handled by this recorder.
        hass.loop.set_exception_handler(previous_handler)

    # Show the captured contexts on failure instead of a bare "False".
    assert not loop_errors, loop_errors
|
|
|
|
|
|
async def test_extract_all_omit_entity_id(hass, caplog):
    """Test that a service call without ``entity_id`` extracts no entities.

    Omitting the target does not select all entities. Even with two
    entities registered, the extraction result is empty.
    """
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    await component.async_add_entities(
        [MockEntity(name="test_1"), MockEntity(name="test_2")]
    )

    call = ha.ServiceCall("test", "service")

    extracted = await component.async_extract_from_service(call)
    assert [ent.entity_id for ent in extracted] == []
|
|
|
|
|
|
async def test_extract_all_use_match_all(hass, caplog):
    """Test that ``entity_id: "all"`` extracts every entity without warning.

    Both registered entities are returned, and the deprecation message about
    not passing an entity ID is not logged.
    """
    component = EntityComponent(_LOGGER, DOMAIN, hass)
    await component.async_add_entities(
        [MockEntity(name="test_1"), MockEntity(name="test_2")]
    )

    # NOTE(review): uses the literal "all" rather than ENTITY_MATCH_ALL,
    # presumably to exercise the raw user-facing value; confirm intent.
    call = ha.ServiceCall("test", "service", {"entity_id": "all"})

    extracted = await component.async_extract_from_service(call)
    assert sorted(ent.entity_id for ent in extracted) == [
        "test_domain.test_1",
        "test_domain.test_2",
    ]
    assert (
        "Not passing an entity ID to a service to target all entities is deprecated"
    ) not in caplog.text
|
|
|
|
|
|
async def test_register_entity_service(hass):
    """Test registering an entity service and calling it on the entity.

    The service is registered with schema ``{"some": str}``.
    - A call carrying an extra ``invalid`` key raises ``vol.Invalid`` without
      invoking the entity method.
    - Calls targeting the entity directly or via ``ENTITY_MATCH_ALL`` invoke
      the method with the remaining service data as keyword arguments.
    - Calls targeting ``ENTITY_MATCH_NONE``, by entity or by area, invoke
      nothing.
    """
    entity = MockEntity(entity_id=f"{DOMAIN}.entity")
    calls = []

    @ha.callback
    def appender(**kwargs):
        """Record the keyword arguments the service passes to the entity."""
        calls.append(kwargs)

    entity.async_called_by_service = appender

    component = EntityComponent(_LOGGER, DOMAIN, hass)
    await component.async_add_entities([entity])

    component.async_register_entity_service(
        "hello", {"some": str}, "async_called_by_service"
    )

    async def call_hello(service_data):
        """Invoke the registered service and wait for it to finish."""
        await hass.services.async_call(DOMAIN, "hello", service_data, blocking=True)

    # An unexpected key fails validation and the entity method is not called.
    with pytest.raises(vol.Invalid):
        await call_hello({"entity_id": entity.entity_id, "invalid": "data"})
    assert len(calls) == 0

    await call_hello({"entity_id": entity.entity_id, "some": "data"})
    assert len(calls) == 1
    assert calls[0] == {"some": "data"}

    await call_hello({"entity_id": ENTITY_MATCH_ALL, "some": "data"})
    assert len(calls) == 2
    assert calls[1] == {"some": "data"}

    # Explicitly targeting nothing must not reach the entity.
    await call_hello({"entity_id": ENTITY_MATCH_NONE, "some": "data"})
    assert len(calls) == 2

    await call_hello({"area_id": ENTITY_MATCH_NONE, "some": "data"})
    assert len(calls) == 2
|