Refactor bayesian observations using dataclass (#79590)

* refactor

* remove some changes

* remove typehint

* improve codestyle

* move docstring to comment

* < 88 chars

* avoid short var names

* more readable

* fix rename

* Update homeassistant/components/bayesian/helpers.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Update homeassistant/components/bayesian/binary_sensor.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Update homeassistant/components/bayesian/binary_sensor.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* no intermediate

* comment why set before list

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
pull/79602/head
HarvsG 2022-10-04 16:16:39 +01:00 committed by GitHub
parent 56dd0a6867
commit dd1463da28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 195 additions and 105 deletions

View File

@ -35,24 +35,24 @@ from homeassistant.helpers.template import result_as_boolean
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import DOMAIN, PLATFORMS
from .const import (
ATTR_OBSERVATIONS,
ATTR_OCCURRED_OBSERVATION_ENTITIES,
ATTR_PROBABILITY,
ATTR_PROBABILITY_THRESHOLD,
CONF_OBSERVATIONS,
CONF_P_GIVEN_F,
CONF_P_GIVEN_T,
CONF_PRIOR,
CONF_PROBABILITY_THRESHOLD,
CONF_TEMPLATE,
CONF_TO_STATE,
DEFAULT_NAME,
DEFAULT_PROBABILITY_THRESHOLD,
)
from .helpers import Observation
from .repairs import raise_mirrored_entries, raise_no_prob_given_false
ATTR_OBSERVATIONS = "observations"
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
ATTR_PROBABILITY = "probability"
ATTR_PROBABILITY_THRESHOLD = "probability_threshold"
CONF_OBSERVATIONS = "observations"
CONF_PRIOR = "prior"
CONF_TEMPLATE = "template"
CONF_PROBABILITY_THRESHOLD = "probability_threshold"
CONF_P_GIVEN_F = "prob_given_false"
CONF_P_GIVEN_T = "prob_given_true"
CONF_TO_STATE = "to_state"
DEFAULT_NAME = "Bayesian Binary Sensor"
DEFAULT_PROBABILITY_THRESHOLD = 0.5
_LOGGER = logging.getLogger(__name__)
@ -156,7 +156,20 @@ class BayesianBinarySensor(BinarySensorEntity):
def __init__(self, name, prior, observations, probability_threshold, device_class):
"""Initialize the Bayesian sensor."""
self._attr_name = name
self._observations = observations
self._observations = [
Observation(
entity_id=observation.get(CONF_ENTITY_ID),
platform=observation[CONF_PLATFORM],
prob_given_false=observation[CONF_P_GIVEN_F],
prob_given_true=observation[CONF_P_GIVEN_T],
observed=None,
to_state=observation.get(CONF_TO_STATE),
above=observation.get(CONF_ABOVE),
below=observation.get(CONF_BELOW),
value_template=observation.get(CONF_VALUE_TEMPLATE),
)
for observation in observations
]
self._probability_threshold = probability_threshold
self._attr_device_class = device_class
self._attr_is_on = False
@ -230,13 +243,18 @@ class BayesianBinarySensor(BinarySensorEntity):
self.entity_id,
)
observation = None
observed = None
else:
observation = result_as_boolean(result)
observed = result_as_boolean(result)
for obs in self.observations_by_template[template]:
obs_entry = {"entity_id": entity, "observation": observation, **obs}
self.current_observations[obs["id"]] = obs_entry
for observation in self.observations_by_template[template]:
observation.observed = observed
# in some cases a template may update because of the absence of an entity
if entity is not None:
observation.entity_id = str(entity)
self.current_observations[observation.id] = observation
if event:
self.async_set_context(event.context)
@ -270,7 +288,7 @@ class BayesianBinarySensor(BinarySensorEntity):
raise_mirrored_entries(
self.hass,
all_template_observations,
text=f"{self._attr_name}/{all_template_observations[0]['value_template']}",
text=f"{self._attr_name}/{all_template_observations[0].value_template}",
)
@callback
@ -289,42 +307,38 @@ class BayesianBinarySensor(BinarySensorEntity):
def _record_entity_observations(self, entity):
local_observations = OrderedDict({})
for entity_obs in self.observations_by_entity[entity]:
platform = entity_obs["platform"]
for observation in self.observations_by_entity[entity]:
platform = observation.platform
observation = self.observation_handlers[platform](entity_obs)
observed = self.observation_handlers[platform](observation)
observation.observed = observed
obs_entry = {
"entity_id": entity,
"observation": observation,
**entity_obs,
}
local_observations[entity_obs["id"]] = obs_entry
local_observations[observation.id] = observation
return local_observations
def _calculate_new_probability(self):
prior = self.prior
for obs in self.current_observations.values():
if obs is not None:
if obs["observation"] is True:
for observation in self.current_observations.values():
if observation is not None:
if observation.observed is True:
prior = update_probability(
prior,
obs["prob_given_true"],
obs["prob_given_false"],
observation.prob_given_true,
observation.prob_given_false,
)
elif obs["observation"] is False:
elif observation.observed is False:
prior = update_probability(
prior,
1 - obs["prob_given_true"],
1 - obs["prob_given_false"],
1 - observation.prob_given_true,
1 - observation.prob_given_false,
)
elif obs["observation"] is None:
if obs["entity_id"] is not None:
elif observation.observed is None:
if observation.entity_id is not None:
_LOGGER.debug(
"Observation for entity '%s' returned None, it will not be used for Bayesian updating",
obs["entity_id"],
observation.entity_id,
)
else:
_LOGGER.debug(
@ -338,8 +352,8 @@ class BayesianBinarySensor(BinarySensorEntity):
Build and return data structure of the form below.
{
"sensor.sensor1": [{"id": 0, ...}, {"id": 1, ...}],
"sensor.sensor2": [{"id": 2, ...}],
"sensor.sensor1": [Observation, Observation],
"sensor.sensor2": [Observation],
...
}
@ -347,21 +361,20 @@ class BayesianBinarySensor(BinarySensorEntity):
for all relevant observations to be looked up via their `entity_id`.
"""
observations_by_entity: dict[str, list[OrderedDict]] = {}
for i, obs in enumerate(self._observations):
obs["id"] = i
observations_by_entity: dict[str, list[Observation]] = {}
for observation in self._observations:
if "entity_id" not in obs:
if (key := observation.entity_id) is None:
continue
observations_by_entity.setdefault(obs["entity_id"], []).append(obs)
observations_by_entity.setdefault(key, []).append(observation)
for li_of_dicts in observations_by_entity.values():
if len(li_of_dicts) == 1:
for entity_observations in observations_by_entity.values():
if len(entity_observations) == 1:
continue
for ord_dict in li_of_dicts:
if ord_dict["platform"] != "state":
for observation in entity_observations:
if observation.platform != "state":
continue
ord_dict["platform"] = "multi_state"
observation.platform = "multi_state"
return observations_by_entity
@ -370,8 +383,8 @@ class BayesianBinarySensor(BinarySensorEntity):
Build and return data structure of the form below.
{
"template": [{"id": 0, ...}, {"id": 1, ...}],
"template2": [{"id": 2, ...}],
"template": [Observation, Observation],
"template2": [Observation],
...
}
@ -380,20 +393,18 @@ class BayesianBinarySensor(BinarySensorEntity):
"""
observations_by_template = {}
for ind, obs in enumerate(self._observations):
obs["id"] = ind
if "value_template" not in obs:
for observation in self._observations:
if observation.value_template is None:
continue
template = obs.get(CONF_VALUE_TEMPLATE)
observations_by_template.setdefault(template, []).append(obs)
template = observation.value_template
observations_by_template.setdefault(template, []).append(observation)
return observations_by_template
def _process_numeric_state(self, entity_observation):
"""Return True if numeric condition is met, return False if not, return None otherwise."""
entity = entity_observation["entity_id"]
entity = entity_observation.entity_id
try:
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
@ -401,61 +412,67 @@ class BayesianBinarySensor(BinarySensorEntity):
return condition.async_numeric_state(
self.hass,
entity,
entity_observation.get("below"),
entity_observation.get("above"),
entity_observation.below,
entity_observation.above,
None,
entity_observation,
entity_observation.to_dict(),
)
except ConditionError:
return None
def _process_state(self, entity_observation):
"""Return True if state conditions are met."""
entity = entity_observation["entity_id"]
"""Return True if state conditions are met, return False if they are not.
Returns None if the state is unavailable.
"""
entity = entity_observation.entity_id
try:
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
return None
return condition.state(
self.hass, entity, entity_observation.get("to_state")
)
return condition.state(self.hass, entity, entity_observation.to_state)
except ConditionError:
return None
def _process_multi_state(self, entity_observation):
"""Return True if state conditions are met."""
entity = entity_observation["entity_id"]
"""Return True if state conditions are met, otherwise return None.
Never return False as all other states should have their own probabilities configured.
"""
entity = entity_observation.entity_id
try:
if condition.state(self.hass, entity, entity_observation.get("to_state")):
if condition.state(self.hass, entity, entity_observation.to_state):
return True
except ConditionError:
return None
return None
@property
def extra_state_attributes(self):
"""Return the state attributes of the sensor."""
attr_observations_list = [
obs.copy() for obs in self.current_observations.values() if obs is not None
]
for item in attr_observations_list:
item.pop("value_template", None)
return {
ATTR_OBSERVATIONS: attr_observations_list,
ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
{
obs.get("entity_id")
for obs in self.current_observations.values()
if obs is not None
and obs.get("entity_id") is not None
and obs.get("observation") is not None
}
),
ATTR_PROBABILITY: round(self.probability, 2),
ATTR_PROBABILITY_THRESHOLD: self._probability_threshold,
# An entity can be in more than one observation so set then list to deduplicate
ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
{
observation.entity_id
for observation in self.current_observations.values()
if observation is not None
and observation.entity_id is not None
and observation.observed is not None
}
),
ATTR_OBSERVATIONS: [
observation.to_dict()
for observation in self.current_observations.values()
if observation is not None
],
}
async def async_update(self) -> None:

View File

@ -0,0 +1,17 @@
"""Consts for using in modules."""
ATTR_OBSERVATIONS = "observations"
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
ATTR_PROBABILITY = "probability"
ATTR_PROBABILITY_THRESHOLD = "probability_threshold"
CONF_OBSERVATIONS = "observations"
CONF_PRIOR = "prior"
CONF_TEMPLATE = "template"
CONF_PROBABILITY_THRESHOLD = "probability_threshold"
CONF_P_GIVEN_F = "prob_given_false"
CONF_P_GIVEN_T = "prob_given_true"
CONF_TO_STATE = "to_state"
DEFAULT_NAME = "Bayesian Binary Sensor"
DEFAULT_PROBABILITY_THRESHOLD = 0.5

View File

@ -0,0 +1,69 @@
"""Helpers to deal with bayesian observations."""
from __future__ import annotations
from dataclasses import dataclass, field
import uuid
from homeassistant.const import (
CONF_ABOVE,
CONF_BELOW,
CONF_ENTITY_ID,
CONF_PLATFORM,
CONF_VALUE_TEMPLATE,
)
from homeassistant.helpers.template import Template
from .const import CONF_P_GIVEN_F, CONF_P_GIVEN_T, CONF_TO_STATE
@dataclass
class Observation:
"""Representation of a sensor or template observation."""
entity_id: str | None
platform: str
prob_given_true: float
prob_given_false: float
to_state: str | None
above: float | None
below: float | None
value_template: Template | None
observed: bool | None = None
id: str = field(default_factory=lambda: str(uuid.uuid4()))
def to_dict(self) -> dict[str, str | float | bool | None]:
"""Represent Class as a Dict for easier serialization."""
# Needed because dataclasses asdict() can't serialize Templates and ignores Properties.
dic = {
CONF_PLATFORM: self.platform,
CONF_ENTITY_ID: self.entity_id,
CONF_VALUE_TEMPLATE: self.template,
CONF_TO_STATE: self.to_state,
CONF_ABOVE: self.above,
CONF_BELOW: self.below,
CONF_P_GIVEN_T: self.prob_given_true,
CONF_P_GIVEN_F: self.prob_given_false,
"observed": self.observed,
}
for key, value in dic.copy().items():
if value is None:
del dic[key]
return dic
def is_mirror(self, other: Observation) -> bool:
"""Dectects whether given observation is a mirror of this one."""
return (
self.platform == other.platform
and round(self.prob_given_true + other.prob_given_true, 1) == 1
and round(self.prob_given_false + other.prob_given_false, 1) == 1
)
@property
def template(self) -> str | None:
"""Not all observations have templates and we want to get template strings."""
if self.value_template is not None:
return self.value_template.template
return None

View File

@ -11,20 +11,7 @@ def raise_mirrored_entries(hass: HomeAssistant, observations, text: str = "") ->
"""If there are mirrored entries, the user is probably using a workaround for a patched bug."""
if len(observations) != 2:
return
true_sums_1: bool = (
round(
observations[0]["prob_given_true"] + observations[1]["prob_given_true"], 1
)
== 1.0
)
false_sums_1: bool = (
round(
observations[0]["prob_given_false"] + observations[1]["prob_given_false"], 1
)
== 1.0
)
same_states: bool = observations[0]["platform"] == observations[1]["platform"]
if true_sums_1 & false_sums_1 & same_states:
if observations[0].is_mirror(observations[1]):
issue_registry.async_create_issue(
hass,
DOMAIN,