Refactor bayesian observations using dataclass (#79590)
* refactor * remove some changes * remove typehint * improve codestyle * move docstring to comment * < 88 chars * avoid short var names * more readable * fix rename * Update homeassistant/components/bayesian/helpers.py Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * Update homeassistant/components/bayesian/binary_sensor.py Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * Update homeassistant/components/bayesian/binary_sensor.py Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * no intermediate * comment why set before list Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>pull/79602/head
parent
56dd0a6867
commit
dd1463da28
|
@ -35,24 +35,24 @@ from homeassistant.helpers.template import result_as_boolean
|
|||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from . import DOMAIN, PLATFORMS
|
||||
from .const import (
|
||||
ATTR_OBSERVATIONS,
|
||||
ATTR_OCCURRED_OBSERVATION_ENTITIES,
|
||||
ATTR_PROBABILITY,
|
||||
ATTR_PROBABILITY_THRESHOLD,
|
||||
CONF_OBSERVATIONS,
|
||||
CONF_P_GIVEN_F,
|
||||
CONF_P_GIVEN_T,
|
||||
CONF_PRIOR,
|
||||
CONF_PROBABILITY_THRESHOLD,
|
||||
CONF_TEMPLATE,
|
||||
CONF_TO_STATE,
|
||||
DEFAULT_NAME,
|
||||
DEFAULT_PROBABILITY_THRESHOLD,
|
||||
)
|
||||
from .helpers import Observation
|
||||
from .repairs import raise_mirrored_entries, raise_no_prob_given_false
|
||||
|
||||
ATTR_OBSERVATIONS = "observations"
|
||||
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
|
||||
ATTR_PROBABILITY = "probability"
|
||||
ATTR_PROBABILITY_THRESHOLD = "probability_threshold"
|
||||
|
||||
CONF_OBSERVATIONS = "observations"
|
||||
CONF_PRIOR = "prior"
|
||||
CONF_TEMPLATE = "template"
|
||||
CONF_PROBABILITY_THRESHOLD = "probability_threshold"
|
||||
CONF_P_GIVEN_F = "prob_given_false"
|
||||
CONF_P_GIVEN_T = "prob_given_true"
|
||||
CONF_TO_STATE = "to_state"
|
||||
|
||||
DEFAULT_NAME = "Bayesian Binary Sensor"
|
||||
DEFAULT_PROBABILITY_THRESHOLD = 0.5
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -156,7 +156,20 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
def __init__(self, name, prior, observations, probability_threshold, device_class):
|
||||
"""Initialize the Bayesian sensor."""
|
||||
self._attr_name = name
|
||||
self._observations = observations
|
||||
self._observations = [
|
||||
Observation(
|
||||
entity_id=observation.get(CONF_ENTITY_ID),
|
||||
platform=observation[CONF_PLATFORM],
|
||||
prob_given_false=observation[CONF_P_GIVEN_F],
|
||||
prob_given_true=observation[CONF_P_GIVEN_T],
|
||||
observed=None,
|
||||
to_state=observation.get(CONF_TO_STATE),
|
||||
above=observation.get(CONF_ABOVE),
|
||||
below=observation.get(CONF_BELOW),
|
||||
value_template=observation.get(CONF_VALUE_TEMPLATE),
|
||||
)
|
||||
for observation in observations
|
||||
]
|
||||
self._probability_threshold = probability_threshold
|
||||
self._attr_device_class = device_class
|
||||
self._attr_is_on = False
|
||||
|
@ -230,13 +243,18 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
self.entity_id,
|
||||
)
|
||||
|
||||
observation = None
|
||||
observed = None
|
||||
else:
|
||||
observation = result_as_boolean(result)
|
||||
observed = result_as_boolean(result)
|
||||
|
||||
for obs in self.observations_by_template[template]:
|
||||
obs_entry = {"entity_id": entity, "observation": observation, **obs}
|
||||
self.current_observations[obs["id"]] = obs_entry
|
||||
for observation in self.observations_by_template[template]:
|
||||
observation.observed = observed
|
||||
|
||||
# in some cases a template may update because of the absence of an entity
|
||||
if entity is not None:
|
||||
observation.entity_id = str(entity)
|
||||
|
||||
self.current_observations[observation.id] = observation
|
||||
|
||||
if event:
|
||||
self.async_set_context(event.context)
|
||||
|
@ -270,7 +288,7 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
raise_mirrored_entries(
|
||||
self.hass,
|
||||
all_template_observations,
|
||||
text=f"{self._attr_name}/{all_template_observations[0]['value_template']}",
|
||||
text=f"{self._attr_name}/{all_template_observations[0].value_template}",
|
||||
)
|
||||
|
||||
@callback
|
||||
|
@ -289,42 +307,38 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
def _record_entity_observations(self, entity):
|
||||
local_observations = OrderedDict({})
|
||||
|
||||
for entity_obs in self.observations_by_entity[entity]:
|
||||
platform = entity_obs["platform"]
|
||||
for observation in self.observations_by_entity[entity]:
|
||||
platform = observation.platform
|
||||
|
||||
observation = self.observation_handlers[platform](entity_obs)
|
||||
observed = self.observation_handlers[platform](observation)
|
||||
observation.observed = observed
|
||||
|
||||
obs_entry = {
|
||||
"entity_id": entity,
|
||||
"observation": observation,
|
||||
**entity_obs,
|
||||
}
|
||||
local_observations[entity_obs["id"]] = obs_entry
|
||||
local_observations[observation.id] = observation
|
||||
|
||||
return local_observations
|
||||
|
||||
def _calculate_new_probability(self):
|
||||
prior = self.prior
|
||||
|
||||
for obs in self.current_observations.values():
|
||||
if obs is not None:
|
||||
if obs["observation"] is True:
|
||||
for observation in self.current_observations.values():
|
||||
if observation is not None:
|
||||
if observation.observed is True:
|
||||
prior = update_probability(
|
||||
prior,
|
||||
obs["prob_given_true"],
|
||||
obs["prob_given_false"],
|
||||
observation.prob_given_true,
|
||||
observation.prob_given_false,
|
||||
)
|
||||
elif obs["observation"] is False:
|
||||
elif observation.observed is False:
|
||||
prior = update_probability(
|
||||
prior,
|
||||
1 - obs["prob_given_true"],
|
||||
1 - obs["prob_given_false"],
|
||||
1 - observation.prob_given_true,
|
||||
1 - observation.prob_given_false,
|
||||
)
|
||||
elif obs["observation"] is None:
|
||||
if obs["entity_id"] is not None:
|
||||
elif observation.observed is None:
|
||||
if observation.entity_id is not None:
|
||||
_LOGGER.debug(
|
||||
"Observation for entity '%s' returned None, it will not be used for Bayesian updating",
|
||||
obs["entity_id"],
|
||||
observation.entity_id,
|
||||
)
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
|
@ -338,8 +352,8 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
Build and return data structure of the form below.
|
||||
|
||||
{
|
||||
"sensor.sensor1": [{"id": 0, ...}, {"id": 1, ...}],
|
||||
"sensor.sensor2": [{"id": 2, ...}],
|
||||
"sensor.sensor1": [Observation, Observation],
|
||||
"sensor.sensor2": [Observation],
|
||||
...
|
||||
}
|
||||
|
||||
|
@ -347,21 +361,20 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
for all relevant observations to be looked up via their `entity_id`.
|
||||
"""
|
||||
|
||||
observations_by_entity: dict[str, list[OrderedDict]] = {}
|
||||
for i, obs in enumerate(self._observations):
|
||||
obs["id"] = i
|
||||
observations_by_entity: dict[str, list[Observation]] = {}
|
||||
for observation in self._observations:
|
||||
|
||||
if "entity_id" not in obs:
|
||||
if (key := observation.entity_id) is None:
|
||||
continue
|
||||
observations_by_entity.setdefault(obs["entity_id"], []).append(obs)
|
||||
observations_by_entity.setdefault(key, []).append(observation)
|
||||
|
||||
for li_of_dicts in observations_by_entity.values():
|
||||
if len(li_of_dicts) == 1:
|
||||
for entity_observations in observations_by_entity.values():
|
||||
if len(entity_observations) == 1:
|
||||
continue
|
||||
for ord_dict in li_of_dicts:
|
||||
if ord_dict["platform"] != "state":
|
||||
for observation in entity_observations:
|
||||
if observation.platform != "state":
|
||||
continue
|
||||
ord_dict["platform"] = "multi_state"
|
||||
observation.platform = "multi_state"
|
||||
|
||||
return observations_by_entity
|
||||
|
||||
|
@ -370,8 +383,8 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
Build and return data structure of the form below.
|
||||
|
||||
{
|
||||
"template": [{"id": 0, ...}, {"id": 1, ...}],
|
||||
"template2": [{"id": 2, ...}],
|
||||
"template": [Observation, Observation],
|
||||
"template2": [Observation],
|
||||
...
|
||||
}
|
||||
|
||||
|
@ -380,20 +393,18 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
"""
|
||||
|
||||
observations_by_template = {}
|
||||
for ind, obs in enumerate(self._observations):
|
||||
obs["id"] = ind
|
||||
|
||||
if "value_template" not in obs:
|
||||
for observation in self._observations:
|
||||
if observation.value_template is None:
|
||||
continue
|
||||
|
||||
template = obs.get(CONF_VALUE_TEMPLATE)
|
||||
observations_by_template.setdefault(template, []).append(obs)
|
||||
template = observation.value_template
|
||||
observations_by_template.setdefault(template, []).append(observation)
|
||||
|
||||
return observations_by_template
|
||||
|
||||
def _process_numeric_state(self, entity_observation):
|
||||
"""Return True if numeric condition is met, return False if not, return None otherwise."""
|
||||
entity = entity_observation["entity_id"]
|
||||
entity = entity_observation.entity_id
|
||||
|
||||
try:
|
||||
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
|
||||
|
@ -401,61 +412,67 @@ class BayesianBinarySensor(BinarySensorEntity):
|
|||
return condition.async_numeric_state(
|
||||
self.hass,
|
||||
entity,
|
||||
entity_observation.get("below"),
|
||||
entity_observation.get("above"),
|
||||
entity_observation.below,
|
||||
entity_observation.above,
|
||||
None,
|
||||
entity_observation,
|
||||
entity_observation.to_dict(),
|
||||
)
|
||||
except ConditionError:
|
||||
return None
|
||||
|
||||
def _process_state(self, entity_observation):
|
||||
"""Return True if state conditions are met."""
|
||||
entity = entity_observation["entity_id"]
|
||||
"""Return True if state conditions are met, return False if they are not.
|
||||
|
||||
Returns None if the state is unavailable.
|
||||
"""
|
||||
|
||||
entity = entity_observation.entity_id
|
||||
|
||||
try:
|
||||
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
|
||||
return None
|
||||
|
||||
return condition.state(
|
||||
self.hass, entity, entity_observation.get("to_state")
|
||||
)
|
||||
return condition.state(self.hass, entity, entity_observation.to_state)
|
||||
except ConditionError:
|
||||
return None
|
||||
|
||||
def _process_multi_state(self, entity_observation):
|
||||
"""Return True if state conditions are met."""
|
||||
entity = entity_observation["entity_id"]
|
||||
"""Return True if state conditions are met, otherwise return None.
|
||||
|
||||
Never return False as all other states should have their own probabilities configured.
|
||||
"""
|
||||
|
||||
entity = entity_observation.entity_id
|
||||
|
||||
try:
|
||||
if condition.state(self.hass, entity, entity_observation.get("to_state")):
|
||||
if condition.state(self.hass, entity, entity_observation.to_state):
|
||||
return True
|
||||
except ConditionError:
|
||||
return None
|
||||
return None
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self):
|
||||
"""Return the state attributes of the sensor."""
|
||||
attr_observations_list = [
|
||||
obs.copy() for obs in self.current_observations.values() if obs is not None
|
||||
]
|
||||
|
||||
for item in attr_observations_list:
|
||||
item.pop("value_template", None)
|
||||
|
||||
return {
|
||||
ATTR_OBSERVATIONS: attr_observations_list,
|
||||
ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
|
||||
{
|
||||
obs.get("entity_id")
|
||||
for obs in self.current_observations.values()
|
||||
if obs is not None
|
||||
and obs.get("entity_id") is not None
|
||||
and obs.get("observation") is not None
|
||||
}
|
||||
),
|
||||
ATTR_PROBABILITY: round(self.probability, 2),
|
||||
ATTR_PROBABILITY_THRESHOLD: self._probability_threshold,
|
||||
# An entity can be in more than one observation so set then list to deduplicate
|
||||
ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
|
||||
{
|
||||
observation.entity_id
|
||||
for observation in self.current_observations.values()
|
||||
if observation is not None
|
||||
and observation.entity_id is not None
|
||||
and observation.observed is not None
|
||||
}
|
||||
),
|
||||
ATTR_OBSERVATIONS: [
|
||||
observation.to_dict()
|
||||
for observation in self.current_observations.values()
|
||||
if observation is not None
|
||||
],
|
||||
}
|
||||
|
||||
async def async_update(self) -> None:
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
"""Consts for using in modules."""
|
||||
|
||||
ATTR_OBSERVATIONS = "observations"
|
||||
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
|
||||
ATTR_PROBABILITY = "probability"
|
||||
ATTR_PROBABILITY_THRESHOLD = "probability_threshold"
|
||||
|
||||
CONF_OBSERVATIONS = "observations"
|
||||
CONF_PRIOR = "prior"
|
||||
CONF_TEMPLATE = "template"
|
||||
CONF_PROBABILITY_THRESHOLD = "probability_threshold"
|
||||
CONF_P_GIVEN_F = "prob_given_false"
|
||||
CONF_P_GIVEN_T = "prob_given_true"
|
||||
CONF_TO_STATE = "to_state"
|
||||
|
||||
DEFAULT_NAME = "Bayesian Binary Sensor"
|
||||
DEFAULT_PROBABILITY_THRESHOLD = 0.5
|
|
@ -0,0 +1,69 @@
|
|||
"""Helpers to deal with bayesian observations."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import uuid
|
||||
|
||||
from homeassistant.const import (
|
||||
CONF_ABOVE,
|
||||
CONF_BELOW,
|
||||
CONF_ENTITY_ID,
|
||||
CONF_PLATFORM,
|
||||
CONF_VALUE_TEMPLATE,
|
||||
)
|
||||
from homeassistant.helpers.template import Template
|
||||
|
||||
from .const import CONF_P_GIVEN_F, CONF_P_GIVEN_T, CONF_TO_STATE
|
||||
|
||||
|
||||
@dataclass
|
||||
class Observation:
|
||||
"""Representation of a sensor or template observation."""
|
||||
|
||||
entity_id: str | None
|
||||
platform: str
|
||||
prob_given_true: float
|
||||
prob_given_false: float
|
||||
to_state: str | None
|
||||
above: float | None
|
||||
below: float | None
|
||||
value_template: Template | None
|
||||
observed: bool | None = None
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
|
||||
def to_dict(self) -> dict[str, str | float | bool | None]:
|
||||
"""Represent Class as a Dict for easier serialization."""
|
||||
|
||||
# Needed because dataclasses asdict() can't serialize Templates and ignores Properties.
|
||||
dic = {
|
||||
CONF_PLATFORM: self.platform,
|
||||
CONF_ENTITY_ID: self.entity_id,
|
||||
CONF_VALUE_TEMPLATE: self.template,
|
||||
CONF_TO_STATE: self.to_state,
|
||||
CONF_ABOVE: self.above,
|
||||
CONF_BELOW: self.below,
|
||||
CONF_P_GIVEN_T: self.prob_given_true,
|
||||
CONF_P_GIVEN_F: self.prob_given_false,
|
||||
"observed": self.observed,
|
||||
}
|
||||
|
||||
for key, value in dic.copy().items():
|
||||
if value is None:
|
||||
del dic[key]
|
||||
|
||||
return dic
|
||||
|
||||
def is_mirror(self, other: Observation) -> bool:
|
||||
"""Dectects whether given observation is a mirror of this one."""
|
||||
return (
|
||||
self.platform == other.platform
|
||||
and round(self.prob_given_true + other.prob_given_true, 1) == 1
|
||||
and round(self.prob_given_false + other.prob_given_false, 1) == 1
|
||||
)
|
||||
|
||||
@property
|
||||
def template(self) -> str | None:
|
||||
"""Not all observations have templates and we want to get template strings."""
|
||||
if self.value_template is not None:
|
||||
return self.value_template.template
|
||||
return None
|
|
@ -11,20 +11,7 @@ def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") ->
|
|||
"""If there are mirrored entries, the user is probably using a workaround for a patched bug."""
|
||||
if len(observations) != 2:
|
||||
return
|
||||
true_sums_1: bool = (
|
||||
round(
|
||||
observations[0]["prob_given_true"] + observations[1]["prob_given_true"], 1
|
||||
)
|
||||
== 1.0
|
||||
)
|
||||
false_sums_1: bool = (
|
||||
round(
|
||||
observations[0]["prob_given_false"] + observations[1]["prob_given_false"], 1
|
||||
)
|
||||
== 1.0
|
||||
)
|
||||
same_states: bool = observations[0]["platform"] == observations[1]["platform"]
|
||||
if true_sums_1 & false_sums_1 & same_states:
|
||||
if observations[0].is_mirror(observations[1]):
|
||||
issue_registry.async_create_issue(
|
||||
hass,
|
||||
DOMAIN,
|
||||
|
|
Loading…
Reference in New Issue