Minor refactor of energy validator (#58209)

pull/59530/head
Erik Montnemery 2021-11-11 07:38:15 +01:00 committed by GitHub
parent 5f8997471d
commit 65b1f0d9eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 204 additions and 113 deletions

View File

@ -67,8 +67,10 @@ class EnergyPreferencesValidation:
return dataclasses.asdict(self)
async def _async_validate_usage_stat(
@callback
def _async_validate_usage_stat(
hass: HomeAssistant,
metadata: dict[str, tuple[int, recorder.models.StatisticMetaData]],
stat_id: str,
allowed_device_classes: Sequence[str],
allowed_units: Mapping[str, Sequence[str]],
@ -76,14 +78,6 @@ async def _async_validate_usage_stat(
result: list[ValidationIssue],
) -> None:
"""Validate a statistic."""
metadata = await hass.async_add_executor_job(
functools.partial(
recorder.statistics.get_metadata,
hass,
statistic_ids=(stat_id,),
)
)
if stat_id not in metadata:
result.append(ValidationIssue("statistics_not_defined", stat_id))
@ -201,18 +195,14 @@ def _async_validate_price_entity(
result.append(ValidationIssue(unit_error, entity_id, unit))
async def _async_validate_cost_stat(
hass: HomeAssistant, stat_id: str, result: list[ValidationIssue]
@callback
def _async_validate_cost_stat(
hass: HomeAssistant,
metadata: dict[str, tuple[int, recorder.models.StatisticMetaData]],
stat_id: str,
result: list[ValidationIssue],
) -> None:
"""Validate that the cost stat is correct."""
metadata = await hass.async_add_executor_job(
functools.partial(
recorder.statistics.get_metadata,
hass,
statistic_ids=(stat_id,),
)
)
if stat_id not in metadata:
result.append(ValidationIssue("statistics_not_defined", stat_id))
@ -266,154 +256,247 @@ def _async_validate_auto_generated_cost_entity(
async def async_validate(hass: HomeAssistant) -> EnergyPreferencesValidation:
"""Validate the energy configuration."""
manager = await data.async_get_manager(hass)
statistics_metadata: dict[str, tuple[int, recorder.models.StatisticMetaData]] = {}
validate_calls = []
wanted_statistics_metadata = set()
result = EnergyPreferencesValidation()
if manager.data is None:
return result
# Create a list of validation checks
for source in manager.data["energy_sources"]:
source_result: list[ValidationIssue] = []
result.energy_sources.append(source_result)
if source["type"] == "grid":
for flow in source["flow_from"]:
await _async_validate_usage_stat(
wanted_statistics_metadata.add(flow["stat_energy_from"])
validate_calls.append(
functools.partial(
_async_validate_usage_stat,
hass,
statistics_metadata,
flow["stat_energy_from"],
ENERGY_USAGE_DEVICE_CLASSES,
ENERGY_USAGE_UNITS,
ENERGY_UNIT_ERROR,
source_result,
)
)
if flow.get("stat_cost") is not None:
await _async_validate_cost_stat(
hass, flow["stat_cost"], source_result
wanted_statistics_metadata.add(flow["stat_cost"])
validate_calls.append(
functools.partial(
_async_validate_cost_stat,
hass,
statistics_metadata,
flow["stat_cost"],
source_result,
)
)
elif flow.get("entity_energy_price") is not None:
_async_validate_price_entity(
validate_calls.append(
functools.partial(
_async_validate_price_entity,
hass,
flow["entity_energy_price"],
source_result,
ENERGY_PRICE_UNITS,
ENERGY_PRICE_UNIT_ERROR,
)
)
if flow.get("entity_energy_from") is not None and (
flow.get("entity_energy_price") is not None
or flow.get("number_energy_price") is not None
):
_async_validate_auto_generated_cost_entity(
validate_calls.append(
functools.partial(
_async_validate_auto_generated_cost_entity,
hass,
flow["entity_energy_from"],
source_result,
)
)
for flow in source["flow_to"]:
await _async_validate_usage_stat(
wanted_statistics_metadata.add(flow["stat_energy_to"])
validate_calls.append(
functools.partial(
_async_validate_usage_stat,
hass,
statistics_metadata,
flow["stat_energy_to"],
ENERGY_USAGE_DEVICE_CLASSES,
ENERGY_USAGE_UNITS,
ENERGY_UNIT_ERROR,
source_result,
)
)
if flow.get("stat_compensation") is not None:
await _async_validate_cost_stat(
hass, flow["stat_compensation"], source_result
wanted_statistics_metadata.add(flow["stat_compensation"])
validate_calls.append(
functools.partial(
_async_validate_cost_stat,
hass,
statistics_metadata,
flow["stat_compensation"],
source_result,
)
)
elif flow.get("entity_energy_price") is not None:
_async_validate_price_entity(
validate_calls.append(
functools.partial(
_async_validate_price_entity,
hass,
flow["entity_energy_price"],
source_result,
ENERGY_PRICE_UNITS,
ENERGY_PRICE_UNIT_ERROR,
)
)
if flow.get("entity_energy_to") is not None and (
flow.get("entity_energy_price") is not None
or flow.get("number_energy_price") is not None
):
_async_validate_auto_generated_cost_entity(
validate_calls.append(
functools.partial(
_async_validate_auto_generated_cost_entity,
hass,
flow["entity_energy_to"],
source_result,
)
)
elif source["type"] == "gas":
await _async_validate_usage_stat(
wanted_statistics_metadata.add(source["stat_energy_from"])
validate_calls.append(
functools.partial(
_async_validate_usage_stat,
hass,
statistics_metadata,
source["stat_energy_from"],
GAS_USAGE_DEVICE_CLASSES,
GAS_USAGE_UNITS,
GAS_UNIT_ERROR,
source_result,
)
)
if source.get("stat_cost") is not None:
await _async_validate_cost_stat(
hass, source["stat_cost"], source_result
wanted_statistics_metadata.add(source["stat_cost"])
validate_calls.append(
functools.partial(
_async_validate_cost_stat,
hass,
statistics_metadata,
source["stat_cost"],
source_result,
)
)
elif source.get("entity_energy_price") is not None:
_async_validate_price_entity(
validate_calls.append(
functools.partial(
_async_validate_price_entity,
hass,
source["entity_energy_price"],
source_result,
GAS_PRICE_UNITS,
GAS_PRICE_UNIT_ERROR,
)
)
if source.get("entity_energy_from") is not None and (
source.get("entity_energy_price") is not None
or source.get("number_energy_price") is not None
):
_async_validate_auto_generated_cost_entity(
validate_calls.append(
functools.partial(
_async_validate_auto_generated_cost_entity,
hass,
source["entity_energy_from"],
source_result,
)
)
elif source["type"] == "solar":
await _async_validate_usage_stat(
wanted_statistics_metadata.add(source["stat_energy_from"])
validate_calls.append(
functools.partial(
_async_validate_usage_stat,
hass,
statistics_metadata,
source["stat_energy_from"],
ENERGY_USAGE_DEVICE_CLASSES,
ENERGY_USAGE_UNITS,
ENERGY_UNIT_ERROR,
source_result,
)
)
elif source["type"] == "battery":
await _async_validate_usage_stat(
wanted_statistics_metadata.add(source["stat_energy_from"])
validate_calls.append(
functools.partial(
_async_validate_usage_stat,
hass,
statistics_metadata,
source["stat_energy_from"],
ENERGY_USAGE_DEVICE_CLASSES,
ENERGY_USAGE_UNITS,
ENERGY_UNIT_ERROR,
source_result,
)
await _async_validate_usage_stat(
)
wanted_statistics_metadata.add(source["stat_energy_to"])
validate_calls.append(
functools.partial(
_async_validate_usage_stat,
hass,
statistics_metadata,
source["stat_energy_to"],
ENERGY_USAGE_DEVICE_CLASSES,
ENERGY_USAGE_UNITS,
ENERGY_UNIT_ERROR,
source_result,
)
)
for device in manager.data["device_consumption"]:
device_result: list[ValidationIssue] = []
result.device_consumption.append(device_result)
await _async_validate_usage_stat(
wanted_statistics_metadata.add(device["stat_consumption"])
validate_calls.append(
functools.partial(
_async_validate_usage_stat,
hass,
statistics_metadata,
device["stat_consumption"],
ENERGY_USAGE_DEVICE_CLASSES,
ENERGY_USAGE_UNITS,
ENERGY_UNIT_ERROR,
device_result,
)
)
# Fetch the needed statistics metadata
statistics_metadata.update(
await hass.async_add_executor_job(
functools.partial(
recorder.statistics.get_metadata,
hass,
statistic_ids=list(wanted_statistics_metadata),
)
)
)
# Execute all the validation checks
for call in validate_calls:
call()
return result

View File

@ -478,7 +478,7 @@ def get_metadata_with_session(
hass: HomeAssistant,
session: scoped_session,
*,
statistic_ids: Iterable[str] | None = None,
statistic_ids: list[str] | tuple[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]:
@ -533,7 +533,7 @@ def get_metadata_with_session(
def get_metadata(
hass: HomeAssistant,
*,
statistic_ids: Iterable[str] | None = None,
statistic_ids: list[str] | tuple[str] | None = None,
statistic_type: Literal["mean"] | Literal["sum"] | None = None,
statistic_source: str | None = None,
) -> dict[str, tuple[int, StatisticMetaData]]:

View File

@ -26,11 +26,19 @@ def mock_get_metadata():
"""Mock recorder.statistics.get_metadata."""
mocks = {}
def _get_metadata(_hass, *, statistic_ids):
result = {}
for statistic_id in statistic_ids:
if statistic_id in mocks:
if mocks[statistic_id] is not None:
result[statistic_id] = mocks[statistic_id]
else:
result[statistic_id] = (1, {})
return result
with patch(
"homeassistant.components.recorder.statistics.get_metadata",
side_effect=lambda hass, statistic_ids: mocks.get(
statistic_ids[0], {statistic_ids[0]: (1, {})}
),
wraps=_get_metadata,
):
yield mocks
@ -361,8 +369,8 @@ async def test_validation_grid(
"""Test validating grid with sensors for energy and cost/compensation."""
mock_is_entity_recorded["sensor.grid_cost_1"] = False
mock_is_entity_recorded["sensor.grid_compensation_1"] = False
mock_get_metadata["sensor.grid_cost_1"] = {}
mock_get_metadata["sensor.grid_compensation_1"] = {}
mock_get_metadata["sensor.grid_cost_1"] = None
mock_get_metadata["sensor.grid_compensation_1"] = None
await mock_energy_manager.async_update(
{
"energy_sources": [
@ -456,8 +464,8 @@ async def test_validation_grid_external_cost_compensation(
hass, mock_energy_manager, mock_is_entity_recorded, mock_get_metadata
):
"""Test validating grid with non entity stats for energy and cost/compensation."""
mock_get_metadata["external:grid_cost_1"] = {}
mock_get_metadata["external:grid_compensation_1"] = {}
mock_get_metadata["external:grid_cost_1"] = None
mock_get_metadata["external:grid_compensation_1"] = None
await mock_energy_manager.async_update(
{
"energy_sources": [