Improve FrozenOrThawed (#105541)

pull/105598/head
Erik Montnemery 2023-12-12 21:19:41 +01:00 committed by GitHub
parent 8bd265c3ae
commit 5bd0833f49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 135 additions and 33 deletions

View File

@ -59,7 +59,7 @@ class FrozenOrThawed(type):
for base in bases: for base in bases:
dataclass_bases.append(getattr(base, "_dataclass", base)) dataclass_bases.append(getattr(base, "_dataclass", base))
cls._dataclass = dataclasses.make_dataclass( cls._dataclass = dataclasses.make_dataclass(
f"{name}_dataclass", class_fields, bases=tuple(dataclass_bases), frozen=True name, class_fields, bases=tuple(dataclass_bases), frozen=True
) )
def __new__( def __new__(
@ -87,15 +87,17 @@ class FrozenOrThawed(type):
class will be a real dataclass, i.e. it's decorated with @dataclass. class will be a real dataclass, i.e. it's decorated with @dataclass.
""" """
if not namespace["_FrozenOrThawed__frozen_or_thawed"]: if not namespace["_FrozenOrThawed__frozen_or_thawed"]:
parent = cls.__mro__[1]
# This class is a real dataclass, optionally inject the parent's annotations # This class is a real dataclass, optionally inject the parent's annotations
if dataclasses.is_dataclass(parent) or not hasattr(parent, "_dataclass"): if all(dataclasses.is_dataclass(base) for base in bases):
# Rely on dataclass inheritance # All direct parents are dataclasses, rely on dataclass inheritance
return return
# Parent is not a dataclass, inject its annotations # Parent is not a dataclass, inject all parents' annotations
cls.__annotations__ = ( annotations: dict = {}
parent._dataclass.__annotations__ | cls.__annotations__ for parent in cls.__mro__[::-1]:
) if parent is object:
continue
annotations |= parent.__annotations__
cls.__annotations__ = annotations
return return
# First try without setting the kw_only flag, and if that fails, try setting it # First try without setting the kw_only flag, and if that fails, try setting it
@ -104,30 +106,15 @@ class FrozenOrThawed(type):
except TypeError: except TypeError:
cls._make_dataclass(name, bases, True) cls._make_dataclass(name, bases, True)
def __delattr__(self: object, name: str) -> None: def __new__(*args: Any, **kwargs: Any) -> object:
"""Delete an attribute. """Create a new instance.
If self is a real dataclass, this is called if the dataclass is not frozen. The function has no named arguments to avoid name collisions with dataclass
If self is not a real dataclass, forward to cls._dataclass.__delattr. field names.
""" """
if dataclasses.is_dataclass(self): cls, *_args = args
return object.__delattr__(self, name) if dataclasses.is_dataclass(cls):
return self._dataclass.__delattr__(self, name) # type: ignore[attr-defined, no-any-return] return object.__new__(cls)
return cls._dataclass(*_args, **kwargs)
def __setattr__(self: object, name: str, value: Any) -> None: cls.__new__ = __new__ # type: ignore[method-assign]
"""Set an attribute.
If self is a real dataclass, this is called if the dataclass is not frozen.
If self is not a real dataclass, forward to cls._dataclass.__setattr__.
"""
if dataclasses.is_dataclass(self):
return object.__setattr__(self, name, value)
return self._dataclass.__setattr__(self, name, value) # type: ignore[attr-defined, no-any-return]
# Set generated dunder methods from the dataclass
# MyPy doesn't understand what's happening, so we ignore it
cls.__delattr__ = __delattr__ # type: ignore[assignment, method-assign]
cls.__eq__ = cls._dataclass.__eq__ # type: ignore[method-assign]
cls.__init__ = cls._dataclass.__init__ # type: ignore[misc]
cls.__repr__ = cls._dataclass.__repr__ # type: ignore[method-assign]
cls.__setattr__ = __setattr__ # type: ignore[assignment, method-assign]

View File

@ -1,6 +1,18 @@
# serializer version: 1 # serializer version: 1
# name: test_entity_description_as_dataclass # name: test_entity_description_as_dataclass
EntityDescription(key='blah', device_class='test', entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name=<UndefinedType._singleton: 0>, translation_key=None, unit_of_measurement=None) dict({
'device_class': 'test',
'entity_category': None,
'entity_registry_enabled_default': True,
'entity_registry_visible_default': True,
'force_update': False,
'has_entity_name': False,
'icon': None,
'key': 'blah',
'name': <UndefinedType._singleton: 0>,
'translation_key': None,
'unit_of_measurement': None,
})
# --- # ---
# name: test_entity_description_as_dataclass.1 # name: test_entity_description_as_dataclass.1
"EntityDescription(key='blah', device_class='test', entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name=<UndefinedType._singleton: 0>, translation_key=None, unit_of_measurement=None)" "EntityDescription(key='blah', device_class='test', entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name=<UndefinedType._singleton: 0>, translation_key=None, unit_of_measurement=None)"
@ -43,3 +55,63 @@
# name: test_extending_entity_description.3 # name: test_extending_entity_description.3
"test_extending_entity_description.<locals>.ThawedEntityDescription(key='blah', device_class=None, entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name='name', translation_key=None, unit_of_measurement=None, extra='foo')" "test_extending_entity_description.<locals>.ThawedEntityDescription(key='blah', device_class=None, entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name='name', translation_key=None, unit_of_measurement=None, extra='foo')"
# --- # ---
# name: test_extending_entity_description.4
dict({
'device_class': None,
'entity_category': None,
'entity_registry_enabled_default': True,
'entity_registry_visible_default': True,
'extension': 'ext',
'extra': 'foo',
'force_update': False,
'has_entity_name': False,
'icon': None,
'key': 'blah',
'name': 'name',
'translation_key': None,
'unit_of_measurement': None,
})
# ---
# name: test_extending_entity_description.5
"test_extending_entity_description.<locals>.MyExtendedEntityDescription(key='blah', device_class=None, entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name='name', translation_key=None, unit_of_measurement=None, extension='ext', extra='foo')"
# ---
# name: test_extending_entity_description.6
dict({
'device_class': None,
'entity_category': None,
'entity_registry_enabled_default': True,
'entity_registry_visible_default': True,
'extra': 'foo',
'force_update': False,
'has_entity_name': False,
'icon': None,
'key': 'blah',
'mixin': 'mixin',
'name': 'name',
'translation_key': None,
'unit_of_measurement': None,
})
# ---
# name: test_extending_entity_description.7
"test_extending_entity_description.<locals>.ComplexEntityDescription1(mixin='mixin', key='blah', device_class=None, entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name='name', translation_key=None, unit_of_measurement=None, extra='foo')"
# ---
# name: test_extending_entity_description.8
dict({
'device_class': None,
'entity_category': None,
'entity_registry_enabled_default': True,
'entity_registry_visible_default': True,
'extra': 'foo',
'force_update': False,
'has_entity_name': False,
'icon': None,
'key': 'blah',
'mixin': 'mixin',
'name': 'name',
'translation_key': None,
'unit_of_measurement': None,
})
# ---
# name: test_extending_entity_description.9
"test_extending_entity_description.<locals>.ComplexEntityDescription2(mixin='mixin', key='blah', device_class=None, entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name='name', translation_key=None, unit_of_measurement=None, extra='foo')"
# ---

View File

@ -1669,6 +1669,7 @@ def test_entity_description_as_dataclass(snapshot: SnapshotAssertion):
with pytest.raises(dataclasses.FrozenInstanceError): with pytest.raises(dataclasses.FrozenInstanceError):
delattr(obj, "name") delattr(obj, "name")
assert dataclasses.is_dataclass(obj)
assert obj == snapshot assert obj == snapshot
assert obj == entity.EntityDescription("blah", device_class="test") assert obj == entity.EntityDescription("blah", device_class="test")
assert repr(obj) == snapshot assert repr(obj) == snapshot
@ -1706,3 +1707,45 @@ def test_extending_entity_description(snapshot: SnapshotAssertion):
assert obj.name == "mutate" assert obj.name == "mutate"
delattr(obj, "key") delattr(obj, "key")
assert not hasattr(obj, "key") assert not hasattr(obj, "key")
# Try multiple levels of FrozenOrThawed
class ExtendedEntityDescription(entity.EntityDescription, frozen_or_thawed=True):
extension: str = None
@dataclasses.dataclass(frozen=True)
class MyExtendedEntityDescription(ExtendedEntityDescription):
extra: str = None
obj = MyExtendedEntityDescription("blah", extension="ext", extra="foo", name="name")
assert obj == snapshot
assert obj == MyExtendedEntityDescription(
"blah", extension="ext", extra="foo", name="name"
)
assert repr(obj) == snapshot
# Try multiple direct parents
@dataclasses.dataclass(frozen=True)
class MyMixin:
mixin: str = None
@dataclasses.dataclass(frozen=True, kw_only=True)
class ComplexEntityDescription1(MyMixin, entity.EntityDescription):
extra: str = None
obj = ComplexEntityDescription1(key="blah", extra="foo", mixin="mixin", name="name")
assert obj == snapshot
assert obj == ComplexEntityDescription1(
key="blah", extra="foo", mixin="mixin", name="name"
)
assert repr(obj) == snapshot
@dataclasses.dataclass(frozen=True, kw_only=True)
class ComplexEntityDescription2(entity.EntityDescription, MyMixin):
extra: str = None
obj = ComplexEntityDescription2(key="blah", extra="foo", mixin="mixin", name="name")
assert obj == snapshot
assert obj == ComplexEntityDescription2(
key="blah", extra="foo", mixin="mixin", name="name"
)
assert repr(obj) == snapshot