Refactor service enumeration methods to better match existing use cases (#108671)

pull/108699/head
J. Nick Koston 2024-01-22 14:21:17 -10:00 committed by GitHub
parent f6bc5c98b3
commit 7c86ab14c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 62 additions and 12 deletions

View File

@ -152,7 +152,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
"""Return list of notify services."""
unordered_services = set()
for service in self.hass.services.async_services().get("notify", {}):
for service in self.hass.services.async_services_for_domain("notify"):
if service not in self._exclude:
unordered_services.add(service)

View File

@ -97,7 +97,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
for service_name in hass.services.async_services()[DOMAIN]:
for service_name in hass.services.async_services_for_domain(DOMAIN):
hass.services.async_remove(DOMAIN, service_name)
conversation.async_unset_agent(hass, entry)

View File

@ -64,7 +64,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
for service_name in hass.services.async_services()[DOMAIN]:
for service_name in hass.services.async_services_for_domain(DOMAIN):
hass.services.async_remove(DOMAIN, service_name)
return unload_ok

View File

@ -81,7 +81,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
for service_name in hass.services.async_services()[DOMAIN]:
for service_name in hass.services.async_services_for_domain(DOMAIN):
hass.services.async_remove(DOMAIN, service_name)
return True

View File

@ -344,7 +344,7 @@ async def async_setup(hass: ha.HomeAssistant, config: ConfigType) -> bool: # no
f"configuration is not valid: {errors}"
)
services = hass.services.async_services()
services = hass.services.async_services_internal()
tasks = [
hass.services.async_call(
domain, SERVICE_RELOAD, context=call.context, blocking=True

View File

@ -110,7 +110,7 @@ SCHEMA_RESET_ENERGY_COUNTER = vol.Schema(
async def async_setup_services(hass: HomeAssistant) -> None:
"""Set up the HomematicIP Cloud services."""
if hass.services.async_services().get(HMIPC_DOMAIN):
if hass.services.async_services_for_domain(HMIPC_DOMAIN):
return
@verify_domain_control(hass, HMIPC_DOMAIN)

View File

@ -132,7 +132,7 @@ def async_get_entities(hass: HomeAssistant) -> dict[str, Entity]:
@callback
def async_setup_services(hass: HomeAssistant) -> None: # noqa: C901
"""Create and register services for the ISY integration."""
existing_services = hass.services.async_services().get(DOMAIN)
existing_services = hass.services.async_services_for_domain(DOMAIN)
if existing_services and SERVICE_SEND_PROGRAM_COMMAND in existing_services:
# Integration-level services have already been added. Return.
return
@ -234,7 +234,7 @@ def async_unload_services(hass: HomeAssistant) -> None:
# There is still another config entry for this domain, don't remove services.
return
existing_services = hass.services.async_services().get(DOMAIN)
existing_services = hass.services.async_services_for_domain(DOMAIN)
if not existing_services or SERVICE_SEND_PROGRAM_COMMAND not in existing_services:
return

View File

@ -79,7 +79,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
if conf is None:
return
existing = hass.services.async_services().get(DOMAIN, {})
existing = hass.services.async_services_for_domain(DOMAIN)
for existing_service in existing:
if existing_service == SERVICE_RELOAD:
continue

View File

@ -2019,10 +2019,36 @@ class ServiceRegistry:
def async_services(self) -> dict[str, dict[str, Service]]:
"""Return dictionary with per domain a list of available services.
This method makes a copy of the registry. This function is expensive,
and should only be used if has_service is not sufficient.
This method must be run in the event loop.
"""
return {domain: service.copy() for domain, service in self._services.items()}
@callback
def async_services_for_domain(self, domain: str) -> dict[str, Service]:
"""Return dictionary with per domain a list of available services.
This method makes a copy of the registry for the domain.
This method must be run in the event loop.
"""
return self._services.get(domain, {}).copy()
@callback
def async_services_internal(self) -> dict[str, dict[str, Service]]:
"""Return dictionary with per domain a list of available services.
This method DOES NOT make a copy of the services like async_services does.
It is only expected to be called from the Home Assistant internals
as a performance optimization when the caller is not going to modify the
returned data.
This method must be run in the event loop.
"""
return self._services
def has_service(self, domain: str, service: str) -> bool:
"""Test if specified service exists.

View File

@ -585,7 +585,7 @@ async def async_get_all_descriptions(
# We don't mutate services here so we avoid calling
# async_services which makes a copy of every services
# dict.
services = hass.services._services # pylint: disable=protected-access
services = hass.services.async_services_internal()
# See if there are new services not seen before.
# Any service that we saw before already has an entry in description_cache.

View File

@ -256,7 +256,7 @@ async def test_turn_on_skips_domains_without_service(
"turn_on",
{"entity_id": ["light.test", "sensor.bla", "binary_sensor.blub", "light.bla"]},
)
service = hass.services._services["homeassistant"]["turn_on"]
service = hass.services.async_services_for_domain("homeassistant")["turn_on"]
with patch(
"homeassistant.core.ServiceRegistry.async_call",

View File

@ -1269,7 +1269,7 @@ def test_service_call_repr() -> None:
)
async def test_serviceregistry_has_service(hass: HomeAssistant) -> None:
async def test_service_registry_has_service(hass: HomeAssistant) -> None:
"""Test has_service method."""
hass.services.async_register("test_domain", "test_service", lambda call: None)
assert len(hass.services.async_services()) == 1
@ -1278,6 +1278,30 @@ async def test_serviceregistry_has_service(hass: HomeAssistant) -> None:
assert not hass.services.has_service("non_existing", "test_service")
async def test_service_registry_service_enumeration(hass: HomeAssistant) -> None:
"""Test enumerating services methods."""
hass.services.async_register("test_domain", "test_service", lambda call: None)
services1 = hass.services.async_services()
services2 = hass.services.async_services()
assert len(services1) == 1
assert services1 == services2
assert services1 is not services2 # should be a copy
services1 = hass.services.async_services_internal()
services2 = hass.services.async_services_internal()
assert len(services1) == 1
assert services1 == services2
assert services1 is services2 # should be the same object
assert hass.services.async_services_for_domain("unknown") == {}
services1 = hass.services.async_services_for_domain("test_domain")
services2 = hass.services.async_services_for_domain("test_domain")
assert len(services1) == 1
assert services1 == services2
assert services1 is not services2 # should be a copy
async def test_serviceregistry_call_with_blocking_done_in_time(
hass: HomeAssistant,
) -> None: