Use entity_sources to determine integration in recorder platforms (#88382)
parent
728f0d5b3b
commit
83e5bf7ae8
|
@ -27,6 +27,7 @@ from .exceptions import HomeAssistantError
|
|||
from .helpers import (
|
||||
area_registry,
|
||||
device_registry,
|
||||
entity,
|
||||
entity_registry,
|
||||
issue_registry,
|
||||
recorder,
|
||||
|
@ -236,6 +237,7 @@ async def load_registries(hass: core.HomeAssistant) -> None:
|
|||
platform.uname().processor # pylint: disable=expression-not-assigned
|
||||
|
||||
# Load the registries and cache the result of platform.uname().processor
|
||||
entity.async_setup(hass)
|
||||
await asyncio.gather(
|
||||
area_registry.async_load(hass),
|
||||
device_registry.async_load(hass),
|
||||
|
|
|
@ -30,7 +30,7 @@ from homeassistant.const import (
|
|||
MATCH_ALL,
|
||||
)
|
||||
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
|
||||
from homeassistant.helpers import entity_registry
|
||||
from homeassistant.helpers.entity import entity_sources
|
||||
from homeassistant.helpers.event import (
|
||||
async_track_time_change,
|
||||
async_track_time_interval,
|
||||
|
@ -185,7 +185,7 @@ class Recorder(threading.Thread):
|
|||
self._queue_watch = threading.Event()
|
||||
self.engine: Engine | None = None
|
||||
self.run_history = RunHistory()
|
||||
self._entity_registry = entity_registry.async_get(hass)
|
||||
self._entity_sources = entity_sources(hass)
|
||||
|
||||
# The entity_filter is exposed on the recorder instance so that
|
||||
# it can be used to see if an entity is being recorded and is called
|
||||
|
@ -878,7 +878,7 @@ class Recorder(threading.Thread):
|
|||
dbstate = States.from_event(event)
|
||||
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event(
|
||||
event,
|
||||
self._entity_registry,
|
||||
self._entity_sources,
|
||||
self._exclude_attributes_by_domain,
|
||||
self.dialect_name,
|
||||
)
|
||||
|
|
|
@ -41,7 +41,6 @@ from homeassistant.const import (
|
|||
MAX_LENGTH_STATE_STATE,
|
||||
)
|
||||
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util.json import (
|
||||
|
@ -460,7 +459,7 @@ class StateAttributes(Base):
|
|||
@staticmethod
|
||||
def shared_attrs_bytes_from_event(
|
||||
event: Event,
|
||||
entity_registry: er.EntityRegistry,
|
||||
entity_sources: dict[str, dict[str, str]],
|
||||
exclude_attrs_by_domain: dict[str, set[str]],
|
||||
dialect: SupportedDialect | None,
|
||||
) -> bytes:
|
||||
|
@ -473,8 +472,8 @@ class StateAttributes(Base):
|
|||
exclude_attrs = set(ALL_DOMAIN_EXCLUDE_ATTRS)
|
||||
if base_platform_attrs := exclude_attrs_by_domain.get(domain):
|
||||
exclude_attrs |= base_platform_attrs
|
||||
if (reg_ent := entity_registry.async_get(state.entity_id)) and (
|
||||
integration_attrs := exclude_attrs_by_domain.get(reg_ent.platform)
|
||||
if (entity_info := entity_sources.get(state.entity_id)) and (
|
||||
integration_attrs := exclude_attrs_by_domain.get(entity_info["domain"])
|
||||
):
|
||||
exclude_attrs |= integration_attrs
|
||||
encoder = json_bytes_strip_null if dialect == PSQL_DIALECT else json_bytes
|
||||
|
|
|
@ -57,11 +57,18 @@ SOURCE_PLATFORM_CONFIG = "platform_config"
|
|||
FLOAT_PRECISION = abs(int(math.floor(math.log10(abs(sys.float_info.epsilon))))) - 1
|
||||
|
||||
|
||||
@callback
|
||||
def async_setup(hass: HomeAssistant) -> None:
|
||||
"""Set up entity sources."""
|
||||
hass.data[DATA_ENTITY_SOURCE] = {}
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def entity_sources(hass: HomeAssistant) -> dict[str, dict[str, str]]:
|
||||
"""Get the entity sources."""
|
||||
return hass.data.get(DATA_ENTITY_SOURCE, {})
|
||||
_entity_sources: dict[str, dict[str, str]] = hass.data[DATA_ENTITY_SOURCE]
|
||||
return _entity_sources
|
||||
|
||||
|
||||
def generate_entity_id(
|
||||
|
@ -868,7 +875,7 @@ class Entity(ABC):
|
|||
else:
|
||||
info["source"] = SOURCE_PLATFORM_CONFIG
|
||||
|
||||
self.hass.data.setdefault(DATA_ENTITY_SOURCE, {})[self.entity_id] = info
|
||||
self.hass.data[DATA_ENTITY_SOURCE][self.entity_id] = info
|
||||
|
||||
if self.registry_entry is not None:
|
||||
# This is an assert as it should never happen, but helps in tests
|
||||
|
|
|
@ -247,6 +247,7 @@ async def async_test_home_assistant(event_loop, load_registries=True):
|
|||
)
|
||||
|
||||
# Load the registries
|
||||
entity.async_setup(hass)
|
||||
if load_registries:
|
||||
with patch("homeassistant.helpers.storage.Store.async_load", return_value=None):
|
||||
await asyncio.gather(
|
||||
|
@ -1087,6 +1088,11 @@ class MockEntity(entity.Entity):
|
|||
"""Return the entity category."""
|
||||
return self._handle("entity_category")
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self) -> Mapping[str, Any] | None:
|
||||
"""Return entity specific state attributes."""
|
||||
return self._handle("extra_state_attributes")
|
||||
|
||||
@property
|
||||
def has_entity_name(self) -> bool:
|
||||
"""Return the has_entity_name name flag."""
|
||||
|
|
|
@ -75,6 +75,8 @@ from .common import (
|
|||
)
|
||||
|
||||
from tests.common import (
|
||||
MockEntity,
|
||||
MockEntityPlatform,
|
||||
async_fire_time_changed,
|
||||
fire_time_changed,
|
||||
get_test_home_assistant,
|
||||
|
@ -2037,12 +2039,6 @@ async def test_excluding_attributes_by_integration(
|
|||
"""Test that an integration's recorder platform can exclude attributes."""
|
||||
state = "restoring_from_db"
|
||||
attributes = {"test_attr": 5, "excluded": 10}
|
||||
entry = entity_registry.async_get_or_create(
|
||||
"test",
|
||||
"fake_integration",
|
||||
"recorder",
|
||||
)
|
||||
entity_id = entry.entity_id
|
||||
mock_platform(
|
||||
hass,
|
||||
"fake_integration.recorder",
|
||||
|
@ -2051,7 +2047,12 @@ async def test_excluding_attributes_by_integration(
|
|||
hass.config.components.add("fake_integration")
|
||||
hass.bus.async_fire(EVENT_COMPONENT_LOADED, {"component": "fake_integration"})
|
||||
await hass.async_block_till_done()
|
||||
hass.states.async_set(entity_id, state, attributes)
|
||||
|
||||
entity_id = "test.fake_integration_recorder"
|
||||
platform = MockEntityPlatform(hass, platform_name="fake_integration")
|
||||
entity_platform = MockEntity(entity_id=entity_id, extra_state_attributes=attributes)
|
||||
await platform.async_add_entities([entity_platform])
|
||||
|
||||
await async_wait_recording_done(hass)
|
||||
|
||||
with session_scope(hass=hass) as session:
|
||||
|
|
|
@ -26,7 +26,6 @@ from homeassistant.const import EVENT_STATE_CHANGED
|
|||
import homeassistant.core as ha
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import InvalidEntityFormatError
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.util import dt, dt as dt_util
|
||||
|
||||
|
||||
|
@ -50,7 +49,7 @@ def test_from_event_to_db_state() -> None:
|
|||
assert state.as_dict() == States.from_event(event).to_native().as_dict()
|
||||
|
||||
|
||||
def test_from_event_to_db_state_attributes(entity_registry: er.EntityRegistry) -> None:
|
||||
def test_from_event_to_db_state_attributes() -> None:
|
||||
"""Test converting event to db state attributes."""
|
||||
attrs = {"this_attr": True}
|
||||
state = ha.State("sensor.temperature", "18", attrs)
|
||||
|
@ -63,7 +62,7 @@ def test_from_event_to_db_state_attributes(entity_registry: er.EntityRegistry) -
|
|||
dialect = SupportedDialect.MYSQL
|
||||
|
||||
db_attrs.shared_attrs = StateAttributes.shared_attrs_bytes_from_event(
|
||||
event, entity_registry, {}, dialect
|
||||
event, {}, {}, dialect
|
||||
)
|
||||
assert db_attrs.to_native() == attrs
|
||||
|
||||
|
|
Loading…
Reference in New Issue