Fix service registration supported features check (#35718)

pull/36090/head
Robert Chmielowiec 2020-05-23 18:11:51 +02:00 committed by GitHub
parent b0578018cb
commit d21cfd869e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 9 deletions

View File

@ -431,7 +431,8 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
# Skip entities that don't have the required feature.
if required_features is not None and not any(
entity.supported_features & feature_set for feature_set in required_features
entity.supported_features & feature_set == feature_set
for feature_set in required_features
):
continue

View File

@ -35,6 +35,10 @@ from tests.common import (
mock_service,
)
SUPPORT_A = 1
SUPPORT_B = 2
SUPPORT_C = 4
@pytest.fixture
def mock_handle_entity_call():
@ -52,17 +56,31 @@ def mock_entities(hass):
entity_id="light.kitchen",
available=True,
should_poll=False,
supported_features=1,
supported_features=SUPPORT_A,
)
living_room = MockEntity(
entity_id="light.living_room",
available=True,
should_poll=False,
supported_features=0,
supported_features=SUPPORT_B,
)
bedroom = MockEntity(
entity_id="light.bedroom",
available=True,
should_poll=False,
supported_features=(SUPPORT_A | SUPPORT_B),
)
bathroom = MockEntity(
entity_id="light.bathroom",
available=True,
should_poll=False,
supported_features=(SUPPORT_B | SUPPORT_C),
)
entities = OrderedDict()
entities[kitchen.entity_id] = kitchen
entities[living_room.entity_id] = living_room
entities[bedroom.entity_id] = bedroom
entities[bathroom.entity_id] = bathroom
return entities
@ -307,18 +325,61 @@ async def test_async_get_all_descriptions(hass):
async def test_call_with_required_features(hass, mock_entities):
"""Test service calls invoked only if entity has required feautres."""
"""Test service calls invoked only if entity has required features."""
test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
test_service_mock,
ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[1],
required_features=[SUPPORT_A],
)
assert len(mock_entities) == 2
# Called once because only one of the entities had the required features
assert test_service_mock.call_count == 2
expected = [
mock_entities["light.kitchen"],
mock_entities["light.bedroom"],
]
actual = [call[0][0] for call in test_service_mock.call_args_list]
assert all(entity in actual for entity in expected)
async def test_call_with_both_required_features(hass, mock_entities):
"""Test service calls invoked only if entity has both features."""
test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
test_service_mock,
ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A | SUPPORT_B],
)
assert test_service_mock.call_count == 1
assert [call[0][0] for call in test_service_mock.call_args_list] == [
mock_entities["light.bedroom"]
]
async def test_call_with_one_of_required_features(hass, mock_entities):
"""Test service calls invoked with one entity having the required features."""
test_service_mock = AsyncMock(return_value=None)
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
test_service_mock,
ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
required_features=[SUPPORT_A, SUPPORT_C],
)
assert test_service_mock.call_count == 3
expected = [
mock_entities["light.kitchen"],
mock_entities["light.bedroom"],
mock_entities["light.bathroom"],
]
actual = [call[0][0] for call in test_service_mock.call_args_list]
assert all(entity in actual for entity in expected)
async def test_call_with_sync_func(hass, mock_entities):
@ -458,7 +519,7 @@ async def test_call_no_context_target_all(hass, mock_handle_entity_call, mock_en
),
)
assert len(mock_handle_entity_call.mock_calls) == 2
assert len(mock_handle_entity_call.mock_calls) == 4
assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
mock_entities.values()
)
@ -494,7 +555,7 @@ async def test_call_with_match_all(
ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}),
)
assert len(mock_handle_entity_call.mock_calls) == 2
assert len(mock_handle_entity_call.mock_calls) == 4
assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list(
mock_entities.values()
)