diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index ce52d188540..af4bdb50fa4 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -431,7 +431,8 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non # Skip entities that don't have the required feature. if required_features is not None and not any( - entity.supported_features & feature_set for feature_set in required_features + entity.supported_features & feature_set == feature_set + for feature_set in required_features ): continue diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index e87fd2646dd..ba72cbc83ca 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -35,6 +35,10 @@ from tests.common import ( mock_service, ) +SUPPORT_A = 1 +SUPPORT_B = 2 +SUPPORT_C = 4 + @pytest.fixture def mock_handle_entity_call(): @@ -52,17 +56,31 @@ def mock_entities(hass): entity_id="light.kitchen", available=True, should_poll=False, - supported_features=1, + supported_features=SUPPORT_A, ) living_room = MockEntity( entity_id="light.living_room", available=True, should_poll=False, - supported_features=0, + supported_features=SUPPORT_B, + ) + bedroom = MockEntity( + entity_id="light.bedroom", + available=True, + should_poll=False, + supported_features=(SUPPORT_A | SUPPORT_B), + ) + bathroom = MockEntity( + entity_id="light.bathroom", + available=True, + should_poll=False, + supported_features=(SUPPORT_B | SUPPORT_C), ) entities = OrderedDict() entities[kitchen.entity_id] = kitchen entities[living_room.entity_id] = living_room + entities[bedroom.entity_id] = bedroom + entities[bathroom.entity_id] = bathroom return entities @@ -307,18 +325,61 @@ async def test_async_get_all_descriptions(hass): async def test_call_with_required_features(hass, mock_entities): - """Test service calls invoked only if entity has required feautres.""" + """Test service calls invoked only if entity has required features.""" test_service_mock = AsyncMock(return_value=None) await service.entity_service_call( hass, [Mock(entities=mock_entities)], test_service_mock, ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}), - required_features=[1], + required_features=[SUPPORT_A], ) - assert len(mock_entities) == 2 - # Called once because only one of the entities had the required features + + assert test_service_mock.call_count == 2 + expected = [ + mock_entities["light.kitchen"], + mock_entities["light.bedroom"], + ] + actual = [call[0][0] for call in test_service_mock.call_args_list] + assert all(entity in actual for entity in expected) + + +async def test_call_with_both_required_features(hass, mock_entities): + """Test service calls invoked only if entity has both features.""" + test_service_mock = AsyncMock(return_value=None) + await service.entity_service_call( + hass, + [Mock(entities=mock_entities)], + test_service_mock, + ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}), + required_features=[SUPPORT_A | SUPPORT_B], + ) + assert test_service_mock.call_count == 1 + assert [call[0][0] for call in test_service_mock.call_args_list] == [ + mock_entities["light.bedroom"] + ] + + +async def test_call_with_one_of_required_features(hass, mock_entities): + """Test service calls invoked with one entity having the required features.""" + test_service_mock = AsyncMock(return_value=None) + await service.entity_service_call( + hass, + [Mock(entities=mock_entities)], + test_service_mock, + ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}), + required_features=[SUPPORT_A, SUPPORT_C], + ) + + assert test_service_mock.call_count == 3 + expected = [ + mock_entities["light.kitchen"], + mock_entities["light.bedroom"], + mock_entities["light.bathroom"], + ] + actual = [call[0][0] for call in test_service_mock.call_args_list] + assert all(entity in actual for entity in expected) async def test_call_with_sync_func(hass, mock_entities): @@ -458,7 +519,7 @@ async def test_call_no_context_target_all(hass, mock_handle_entity_call, mock_en ), ) - assert len(mock_handle_entity_call.mock_calls) == 2 + assert len(mock_handle_entity_call.mock_calls) == 4 assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list( mock_entities.values() ) @@ -494,7 +555,7 @@ async def test_call_with_match_all( ha.ServiceCall("test_domain", "test_service", {"entity_id": "all"}), ) - assert len(mock_handle_entity_call.mock_calls) == 2 + assert len(mock_handle_entity_call.mock_calls) == 4 assert [call[1][1] for call in mock_handle_entity_call.mock_calls] == list( mock_entities.values() )