Use entity_sources to determine integration in recorder platforms (#88382)

pull/88389/head
Erik Montnemery 2023-02-18 14:21:41 +01:00 committed by GitHub
parent 728f0d5b3b
commit 83e5bf7ae8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 33 additions and 19 deletions

View File

@ -27,6 +27,7 @@ from .exceptions import HomeAssistantError
from .helpers import (
area_registry,
device_registry,
entity,
entity_registry,
issue_registry,
recorder,
@ -236,6 +237,7 @@ async def load_registries(hass: core.HomeAssistant) -> None:
platform.uname().processor # pylint: disable=expression-not-assigned
# Load the registries and cache the result of platform.uname().processor
entity.async_setup(hass)
await asyncio.gather(
area_registry.async_load(hass),
device_registry.async_load(hass),

View File

@ -30,7 +30,7 @@ from homeassistant.const import (
MATCH_ALL,
)
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.helpers import entity_registry
from homeassistant.helpers.entity import entity_sources
from homeassistant.helpers.event import (
async_track_time_change,
async_track_time_interval,
@ -185,7 +185,7 @@ class Recorder(threading.Thread):
self._queue_watch = threading.Event()
self.engine: Engine | None = None
self.run_history = RunHistory()
self._entity_registry = entity_registry.async_get(hass)
self._entity_sources = entity_sources(hass)
# The entity_filter is exposed on the recorder instance so that
# it can be used to see if an entity is being recorded and is called
@ -878,7 +878,7 @@ class Recorder(threading.Thread):
dbstate = States.from_event(event)
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event(
event,
self._entity_registry,
self._entity_sources,
self._exclude_attributes_by_domain,
self.dialect_name,
)

View File

@ -41,7 +41,6 @@ from homeassistant.const import (
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null
import homeassistant.util.dt as dt_util
from homeassistant.util.json import (
@ -460,7 +459,7 @@ class StateAttributes(Base):
@staticmethod
def shared_attrs_bytes_from_event(
event: Event,
entity_registry: er.EntityRegistry,
entity_sources: dict[str, dict[str, str]],
exclude_attrs_by_domain: dict[str, set[str]],
dialect: SupportedDialect | None,
) -> bytes:
@ -473,8 +472,8 @@ class StateAttributes(Base):
exclude_attrs = set(ALL_DOMAIN_EXCLUDE_ATTRS)
if base_platform_attrs := exclude_attrs_by_domain.get(domain):
exclude_attrs |= base_platform_attrs
if (reg_ent := entity_registry.async_get(state.entity_id)) and (
integration_attrs := exclude_attrs_by_domain.get(reg_ent.platform)
if (entity_info := entity_sources.get(state.entity_id)) and (
integration_attrs := exclude_attrs_by_domain.get(entity_info["domain"])
):
exclude_attrs |= integration_attrs
encoder = json_bytes_strip_null if dialect == PSQL_DIALECT else json_bytes

View File

@ -57,11 +57,18 @@ SOURCE_PLATFORM_CONFIG = "platform_config"
FLOAT_PRECISION = abs(int(math.floor(math.log10(abs(sys.float_info.epsilon))))) - 1
@callback
def async_setup(hass: HomeAssistant) -> None:
"""Set up entity sources."""
hass.data[DATA_ENTITY_SOURCE] = {}
@callback
@bind_hass
def entity_sources(hass: HomeAssistant) -> dict[str, dict[str, str]]:
"""Get the entity sources."""
return hass.data.get(DATA_ENTITY_SOURCE, {})
_entity_sources: dict[str, dict[str, str]] = hass.data[DATA_ENTITY_SOURCE]
return _entity_sources
def generate_entity_id(
@ -868,7 +875,7 @@ class Entity(ABC):
else:
info["source"] = SOURCE_PLATFORM_CONFIG
self.hass.data.setdefault(DATA_ENTITY_SOURCE, {})[self.entity_id] = info
self.hass.data[DATA_ENTITY_SOURCE][self.entity_id] = info
if self.registry_entry is not None:
# This is an assert as it should never happen, but helps in tests

View File

@ -247,6 +247,7 @@ async def async_test_home_assistant(event_loop, load_registries=True):
)
# Load the registries
entity.async_setup(hass)
if load_registries:
with patch("homeassistant.helpers.storage.Store.async_load", return_value=None):
await asyncio.gather(
@ -1087,6 +1088,11 @@ class MockEntity(entity.Entity):
"""Return the entity category."""
return self._handle("entity_category")
@property
def extra_state_attributes(self) -> Mapping[str, Any] | None:
"""Return entity specific state attributes."""
return self._handle("extra_state_attributes")
@property
def has_entity_name(self) -> bool:
"""Return the has_entity_name name flag."""

View File

@ -75,6 +75,8 @@ from .common import (
)
from tests.common import (
MockEntity,
MockEntityPlatform,
async_fire_time_changed,
fire_time_changed,
get_test_home_assistant,
@ -2037,12 +2039,6 @@ async def test_excluding_attributes_by_integration(
"""Test that an integration's recorder platform can exclude attributes."""
state = "restoring_from_db"
attributes = {"test_attr": 5, "excluded": 10}
entry = entity_registry.async_get_or_create(
"test",
"fake_integration",
"recorder",
)
entity_id = entry.entity_id
mock_platform(
hass,
"fake_integration.recorder",
@ -2051,7 +2047,12 @@ async def test_excluding_attributes_by_integration(
hass.config.components.add("fake_integration")
hass.bus.async_fire(EVENT_COMPONENT_LOADED, {"component": "fake_integration"})
await hass.async_block_till_done()
hass.states.async_set(entity_id, state, attributes)
entity_id = "test.fake_integration_recorder"
platform = MockEntityPlatform(hass, platform_name="fake_integration")
entity_platform = MockEntity(entity_id=entity_id, extra_state_attributes=attributes)
await platform.async_add_entities([entity_platform])
await async_wait_recording_done(hass)
with session_scope(hass=hass) as session:

View File

@ -26,7 +26,6 @@ from homeassistant.const import EVENT_STATE_CHANGED
import homeassistant.core as ha
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import InvalidEntityFormatError
from homeassistant.helpers import entity_registry as er
from homeassistant.util import dt, dt as dt_util
@ -50,7 +49,7 @@ def test_from_event_to_db_state() -> None:
assert state.as_dict() == States.from_event(event).to_native().as_dict()
def test_from_event_to_db_state_attributes(entity_registry: er.EntityRegistry) -> None:
def test_from_event_to_db_state_attributes() -> None:
"""Test converting event to db state attributes."""
attrs = {"this_attr": True}
state = ha.State("sensor.temperature", "18", attrs)
@ -63,7 +62,7 @@ def test_from_event_to_db_state_attributes(entity_registry: er.EntityRegistry) -
dialect = SupportedDialect.MYSQL
db_attrs.shared_attrs = StateAttributes.shared_attrs_bytes_from_event(
event, entity_registry, {}, dialect
event, {}, {}, dialect
)
assert db_attrs.to_native() == attrs