Raise exception for invalid call to DeviceRegistry.async_get_or_create (#49038)

* Raise exception instead of returning None for DeviceRegistry.async_get_or_create

* fix entity_platform logic
pull/49166/head
Raman Gupta 2021-04-13 08:18:51 -04:00 committed by GitHub
parent 2b79c91813
commit 769923e8dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 45 additions and 14 deletions

View File

@ -183,3 +183,18 @@ class MaxLengthExceeded(HomeAssistantError):
self.value = value
self.property_name = property_name
self.max_length = max_length
class RequiredParameterMissing(HomeAssistantError):
"""Raised when a required parameter is missing from a function call."""
def __init__(self, parameter_names: list[str]) -> None:
"""Initialize error."""
super().__init__(
self,
(
"Call must include at least one of the following parameters: "
f"{', '.join(parameter_names)}"
),
)
self.parameter_names = parameter_names

View File

@ -10,6 +10,7 @@ import attr
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import RequiredParameterMissing
from homeassistant.loader import bind_hass
import homeassistant.util.uuid as uuid_util
@ -259,10 +260,10 @@ class DeviceRegistry:
# To disable a device if it gets created
disabled_by: str | None | UndefinedType = UNDEFINED,
suggested_area: str | None | UndefinedType = UNDEFINED,
) -> DeviceEntry | None:
) -> DeviceEntry:
"""Get device. Create if it doesn't exist."""
if not identifiers and not connections:
return None
raise RequiredParameterMissing(["identifiers", "connections"])
if identifiers is None:
identifiers = set()
@ -300,7 +301,7 @@ class DeviceRegistry:
else:
via_device_id = UNDEFINED
return self._async_update_device(
device = self._async_update_device(
device.id,
add_config_entry_id=config_entry_id,
via_device_id=via_device_id,
@ -315,6 +316,11 @@ class DeviceRegistry:
suggested_area=suggested_area,
)
# This is safe because _async_update_device will always return a device
# in this use case.
assert device
return device
@callback
def async_update_device(
self,

View File

@ -24,7 +24,11 @@ from homeassistant.core import (
split_entity_id,
valid_entity_id,
)
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.exceptions import (
HomeAssistantError,
PlatformNotReady,
RequiredParameterMissing,
)
from homeassistant.helpers import (
config_validation as cv,
device_registry as dev_reg,
@ -434,9 +438,11 @@ class EntityPlatform:
if key in device_info:
processed_dev_info[key] = device_info[key]
device = device_registry.async_get_or_create(**processed_dev_info)
if device:
try:
device = device_registry.async_get_or_create(**processed_dev_info)
device_id = device.id
except RequiredParameterMissing:
pass
disabled_by: str | None = None
if not entity.entity_registry_enabled_default:

View File

@ -6,6 +6,7 @@ import pytest
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import CoreState, callback
from homeassistant.exceptions import RequiredParameterMissing
from homeassistant.helpers import device_registry, entity_registry
from tests.common import (
@ -114,18 +115,21 @@ async def test_requirement_for_identifier_or_connection(registry):
manufacturer="manufacturer",
model="model",
)
entry3 = registry.async_get_or_create(
config_entry_id="1234",
connections=set(),
identifiers=set(),
manufacturer="manufacturer",
model="model",
)
assert len(registry.devices) == 2
assert entry
assert entry2
assert entry3 is None
with pytest.raises(RequiredParameterMissing) as exc_info:
registry.async_get_or_create(
config_entry_id="1234",
connections=set(),
identifiers=set(),
manufacturer="manufacturer",
model="model",
)
assert exc_info.value.parameter_names == ["identifiers", "connections"]
async def test_multiple_config_entries(registry):