diff --git a/homeassistant/components/energy/validate.py b/homeassistant/components/energy/validate.py index b2a939bffce..d77ea75d36c 100644 --- a/homeassistant/components/energy/validate.py +++ b/homeassistant/components/energy/validate.py @@ -67,8 +67,10 @@ class EnergyPreferencesValidation: return dataclasses.asdict(self) -async def _async_validate_usage_stat( +@callback +def _async_validate_usage_stat( hass: HomeAssistant, + metadata: dict[str, tuple[int, recorder.models.StatisticMetaData]], stat_id: str, allowed_device_classes: Sequence[str], allowed_units: Mapping[str, Sequence[str]], @@ -76,14 +78,6 @@ async def _async_validate_usage_stat( result: list[ValidationIssue], ) -> None: """Validate a statistic.""" - metadata = await hass.async_add_executor_job( - functools.partial( - recorder.statistics.get_metadata, - hass, - statistic_ids=(stat_id,), - ) - ) - if stat_id not in metadata: result.append(ValidationIssue("statistics_not_defined", stat_id)) @@ -201,18 +195,14 @@ def _async_validate_price_entity( result.append(ValidationIssue(unit_error, entity_id, unit)) -async def _async_validate_cost_stat( - hass: HomeAssistant, stat_id: str, result: list[ValidationIssue] +@callback +def _async_validate_cost_stat( + hass: HomeAssistant, + metadata: dict[str, tuple[int, recorder.models.StatisticMetaData]], + stat_id: str, + result: list[ValidationIssue], ) -> None: """Validate that the cost stat is correct.""" - metadata = await hass.async_add_executor_job( - functools.partial( - recorder.statistics.get_metadata, - hass, - statistic_ids=(stat_id,), - ) - ) - if stat_id not in metadata: result.append(ValidationIssue("statistics_not_defined", stat_id)) @@ -266,154 +256,247 @@ def _async_validate_auto_generated_cost_entity( async def async_validate(hass: HomeAssistant) -> EnergyPreferencesValidation: """Validate the energy configuration.""" manager = await data.async_get_manager(hass) + statistics_metadata: dict[str, tuple[int, recorder.models.StatisticMetaData]] = {} + validate_calls = [] + wanted_statistics_metadata = set() result = EnergyPreferencesValidation() if manager.data is None: return result + # Create a list of validation checks for source in manager.data["energy_sources"]: source_result: list[ValidationIssue] = [] result.energy_sources.append(source_result) if source["type"] == "grid": for flow in source["flow_from"]: - await _async_validate_usage_stat( - hass, - flow["stat_energy_from"], - ENERGY_USAGE_DEVICE_CLASSES, - ENERGY_USAGE_UNITS, - ENERGY_UNIT_ERROR, - source_result, + wanted_statistics_metadata.add(flow["stat_energy_from"]) + validate_calls.append( + functools.partial( + _async_validate_usage_stat, + hass, + statistics_metadata, + flow["stat_energy_from"], + ENERGY_USAGE_DEVICE_CLASSES, + ENERGY_USAGE_UNITS, + ENERGY_UNIT_ERROR, + source_result, + ) ) if flow.get("stat_cost") is not None: - await _async_validate_cost_stat( - hass, flow["stat_cost"], source_result + wanted_statistics_metadata.add(flow["stat_cost"]) + validate_calls.append( + functools.partial( + _async_validate_cost_stat, + hass, + statistics_metadata, + flow["stat_cost"], + source_result, + ) ) elif flow.get("entity_energy_price") is not None: - _async_validate_price_entity( - hass, - flow["entity_energy_price"], - source_result, - ENERGY_PRICE_UNITS, - ENERGY_PRICE_UNIT_ERROR, + validate_calls.append( + functools.partial( + _async_validate_price_entity, + hass, + flow["entity_energy_price"], + source_result, + ENERGY_PRICE_UNITS, + ENERGY_PRICE_UNIT_ERROR, + ) ) if flow.get("entity_energy_from") is not None and ( flow.get("entity_energy_price") is not None or flow.get("number_energy_price") is not None ): - _async_validate_auto_generated_cost_entity( - hass, - flow["entity_energy_from"], - source_result, + validate_calls.append( + functools.partial( + _async_validate_auto_generated_cost_entity, + hass, + flow["entity_energy_from"], + source_result, + ) ) for flow in source["flow_to"]: - await _async_validate_usage_stat( - hass, - flow["stat_energy_to"], - ENERGY_USAGE_DEVICE_CLASSES, - ENERGY_USAGE_UNITS, - ENERGY_UNIT_ERROR, - source_result, + wanted_statistics_metadata.add(flow["stat_energy_to"]) + validate_calls.append( + functools.partial( + _async_validate_usage_stat, + hass, + statistics_metadata, + flow["stat_energy_to"], + ENERGY_USAGE_DEVICE_CLASSES, + ENERGY_USAGE_UNITS, + ENERGY_UNIT_ERROR, + source_result, + ) ) if flow.get("stat_compensation") is not None: - await _async_validate_cost_stat( - hass, flow["stat_compensation"], source_result + wanted_statistics_metadata.add(flow["stat_compensation"]) + validate_calls.append( + functools.partial( + _async_validate_cost_stat, + hass, + statistics_metadata, + flow["stat_compensation"], + source_result, + ) ) elif flow.get("entity_energy_price") is not None: - _async_validate_price_entity( - hass, - flow["entity_energy_price"], - source_result, - ENERGY_PRICE_UNITS, - ENERGY_PRICE_UNIT_ERROR, + validate_calls.append( + functools.partial( + _async_validate_price_entity, + hass, + flow["entity_energy_price"], + source_result, + ENERGY_PRICE_UNITS, + ENERGY_PRICE_UNIT_ERROR, + ) ) if flow.get("entity_energy_to") is not None and ( flow.get("entity_energy_price") is not None or flow.get("number_energy_price") is not None ): - _async_validate_auto_generated_cost_entity( - hass, - flow["entity_energy_to"], - source_result, + validate_calls.append( + functools.partial( + _async_validate_auto_generated_cost_entity, + hass, + flow["entity_energy_to"], + source_result, + ) ) elif source["type"] == "gas": - await _async_validate_usage_stat( - hass, - source["stat_energy_from"], - GAS_USAGE_DEVICE_CLASSES, - GAS_USAGE_UNITS, - GAS_UNIT_ERROR, - source_result, + wanted_statistics_metadata.add(source["stat_energy_from"]) + validate_calls.append( + functools.partial( + _async_validate_usage_stat, + hass, + statistics_metadata, + source["stat_energy_from"], + GAS_USAGE_DEVICE_CLASSES, + GAS_USAGE_UNITS, + GAS_UNIT_ERROR, + source_result, + ) ) if source.get("stat_cost") is not None: - await _async_validate_cost_stat( - hass, source["stat_cost"], source_result + wanted_statistics_metadata.add(source["stat_cost"]) + validate_calls.append( + functools.partial( + _async_validate_cost_stat, + hass, + statistics_metadata, + source["stat_cost"], + source_result, + ) ) elif source.get("entity_energy_price") is not None: - _async_validate_price_entity( - hass, - source["entity_energy_price"], - source_result, - GAS_PRICE_UNITS, - GAS_PRICE_UNIT_ERROR, + validate_calls.append( + functools.partial( + _async_validate_price_entity, + hass, + source["entity_energy_price"], + source_result, + GAS_PRICE_UNITS, + GAS_PRICE_UNIT_ERROR, + ) ) if source.get("entity_energy_from") is not None and ( source.get("entity_energy_price") is not None or source.get("number_energy_price") is not None ): - _async_validate_auto_generated_cost_entity( - hass, - source["entity_energy_from"], - source_result, + validate_calls.append( + functools.partial( + _async_validate_auto_generated_cost_entity, + hass, + source["entity_energy_from"], + source_result, + ) ) elif source["type"] == "solar": - await _async_validate_usage_stat( - hass, - source["stat_energy_from"], - ENERGY_USAGE_DEVICE_CLASSES, - ENERGY_USAGE_UNITS, - ENERGY_UNIT_ERROR, - source_result, + wanted_statistics_metadata.add(source["stat_energy_from"]) + validate_calls.append( + functools.partial( + _async_validate_usage_stat, + hass, + statistics_metadata, + source["stat_energy_from"], + ENERGY_USAGE_DEVICE_CLASSES, + ENERGY_USAGE_UNITS, + ENERGY_UNIT_ERROR, + source_result, + ) ) elif source["type"] == "battery": - await _async_validate_usage_stat( - hass, - source["stat_energy_from"], - ENERGY_USAGE_DEVICE_CLASSES, - ENERGY_USAGE_UNITS, - ENERGY_UNIT_ERROR, - source_result, + wanted_statistics_metadata.add(source["stat_energy_from"]) + validate_calls.append( + functools.partial( + _async_validate_usage_stat, + hass, + statistics_metadata, + source["stat_energy_from"], + ENERGY_USAGE_DEVICE_CLASSES, + ENERGY_USAGE_UNITS, + ENERGY_UNIT_ERROR, + source_result, + ) ) - await _async_validate_usage_stat( - hass, - source["stat_energy_to"], - ENERGY_USAGE_DEVICE_CLASSES, - ENERGY_USAGE_UNITS, - ENERGY_UNIT_ERROR, - source_result, + wanted_statistics_metadata.add(source["stat_energy_to"]) + validate_calls.append( + functools.partial( + _async_validate_usage_stat, + hass, + statistics_metadata, + source["stat_energy_to"], + ENERGY_USAGE_DEVICE_CLASSES, + ENERGY_USAGE_UNITS, + ENERGY_UNIT_ERROR, + source_result, + ) ) for device in manager.data["device_consumption"]: device_result: list[ValidationIssue] = [] result.device_consumption.append(device_result) - await _async_validate_usage_stat( - hass, - device["stat_consumption"], - ENERGY_USAGE_DEVICE_CLASSES, - ENERGY_USAGE_UNITS, - ENERGY_UNIT_ERROR, - device_result, + wanted_statistics_metadata.add(device["stat_consumption"]) + validate_calls.append( + functools.partial( + _async_validate_usage_stat, + hass, + statistics_metadata, + device["stat_consumption"], + ENERGY_USAGE_DEVICE_CLASSES, + ENERGY_USAGE_UNITS, + ENERGY_UNIT_ERROR, + device_result, + ) ) + # Fetch the needed statistics metadata + statistics_metadata.update( + await hass.async_add_executor_job( + functools.partial( + recorder.statistics.get_metadata, + hass, + statistic_ids=list(wanted_statistics_metadata), + ) + ) + ) + + # Execute all the validation checks + for call in validate_calls: + call() + return result diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index e5fe84cc874..83de258ea5d 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -478,7 +478,7 @@ def get_metadata_with_session( hass: HomeAssistant, session: scoped_session, *, - statistic_ids: Iterable[str] | None = None, + statistic_ids: list[str] | tuple[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_source: str | None = None, ) -> dict[str, tuple[int, StatisticMetaData]]: @@ -533,7 +533,7 @@ def get_metadata_with_session( def get_metadata( hass: HomeAssistant, *, - statistic_ids: Iterable[str] | None = None, + statistic_ids: list[str] | tuple[str] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_source: str | None = None, ) -> dict[str, tuple[int, StatisticMetaData]]: diff --git a/tests/components/energy/test_validate.py b/tests/components/energy/test_validate.py index 5e3ad5c4aff..78a61b1bf69 100644 --- a/tests/components/energy/test_validate.py +++ b/tests/components/energy/test_validate.py @@ -26,11 +26,19 @@ def mock_get_metadata(): """Mock recorder.statistics.get_metadata.""" mocks = {} + def _get_metadata(_hass, *, statistic_ids): + result = {} + for statistic_id in statistic_ids: + if statistic_id in mocks: + if mocks[statistic_id] is not None: + result[statistic_id] = mocks[statistic_id] + else: + result[statistic_id] = (1, {}) + return result + with patch( "homeassistant.components.recorder.statistics.get_metadata", - side_effect=lambda hass, statistic_ids: mocks.get( - statistic_ids[0], {statistic_ids[0]: (1, {})} - ), + wraps=_get_metadata, ): yield mocks @@ -361,8 +369,8 @@ async def test_validation_grid( """Test validating grid with sensors for energy and cost/compensation.""" mock_is_entity_recorded["sensor.grid_cost_1"] = False mock_is_entity_recorded["sensor.grid_compensation_1"] = False - mock_get_metadata["sensor.grid_cost_1"] = {} - mock_get_metadata["sensor.grid_compensation_1"] = {} + mock_get_metadata["sensor.grid_cost_1"] = None + mock_get_metadata["sensor.grid_compensation_1"] = None await mock_energy_manager.async_update( { "energy_sources": [ @@ -456,8 +464,8 @@ async def test_validation_grid_external_cost_compensation( hass, mock_energy_manager, mock_is_entity_recorded, mock_get_metadata ): """Test validating grid with non entity stats for energy and cost/compensation.""" - mock_get_metadata["external:grid_cost_1"] = {} - mock_get_metadata["external:grid_compensation_1"] = {} + mock_get_metadata["external:grid_cost_1"] = None + mock_get_metadata["external:grid_compensation_1"] = None await mock_energy_manager.async_update( { "energy_sources": [