"""Use Bayesian Inference to trigger a binary sensor."""
from __future__ import annotations

from collections import OrderedDict
from collections.abc import Callable
import logging
from typing import Any
from uuid import UUID

import voluptuous as vol

from homeassistant.components.binary_sensor import (
    PLATFORM_SCHEMA,
    BinarySensorDeviceClass,
    BinarySensorEntity,
)
from homeassistant.const import (
    CONF_ABOVE,
    CONF_BELOW,
    CONF_DEVICE_CLASS,
    CONF_ENTITY_ID,
    CONF_NAME,
    CONF_PLATFORM,
    CONF_STATE,
    CONF_UNIQUE_ID,
    CONF_VALUE_TEMPLATE,
    STATE_UNAVAILABLE,
    STATE_UNKNOWN,
)
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import ConditionError, TemplateError
from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
    TrackTemplate,
    TrackTemplateResult,
    TrackTemplateResultInfo,
    async_track_state_change_event,
    async_track_template_result,
)
from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.template import Template, result_as_boolean
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType

from . import DOMAIN, PLATFORMS
from .const import (
    ATTR_OBSERVATIONS,
    ATTR_OCCURRED_OBSERVATION_ENTITIES,
    ATTR_PROBABILITY,
    ATTR_PROBABILITY_THRESHOLD,
    CONF_OBSERVATIONS,
    CONF_P_GIVEN_F,
    CONF_P_GIVEN_T,
    CONF_PRIOR,
    CONF_PROBABILITY_THRESHOLD,
    CONF_TEMPLATE,
    CONF_TO_STATE,
    DEFAULT_NAME,
    DEFAULT_PROBABILITY_THRESHOLD,
)
from .helpers import Observation
from .issues import raise_mirrored_entries, raise_no_prob_given_false

_LOGGER = logging.getLogger(__name__)


# Each observation schema is built with ``required=True``, so plain (non-marker)
# keys such as ``CONF_PLATFORM`` are mandatory and must equal the given literal.
# prob_given_false is still Optional here; entries missing it are reported and
# dropped in async_setup_platform.
NUMERIC_STATE_SCHEMA = vol.Schema(
    {
        CONF_PLATFORM: "numeric_state",
        vol.Required(CONF_ENTITY_ID): cv.entity_id,
        vol.Optional(CONF_ABOVE): vol.Coerce(float),
        vol.Optional(CONF_BELOW): vol.Coerce(float),
        vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
        vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
    },
    required=True,
)

STATE_SCHEMA = vol.Schema(
    {
        CONF_PLATFORM: CONF_STATE,
        vol.Required(CONF_ENTITY_ID): cv.entity_id,
        vol.Required(CONF_TO_STATE): cv.string,
        vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
        vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
    },
    required=True,
)

TEMPLATE_SCHEMA = vol.Schema(
    {
        CONF_PLATFORM: CONF_TEMPLATE,
        vol.Required(CONF_VALUE_TEMPLATE): cv.template,
        vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
        vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
    },
    required=True,
)

PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
    {
        vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
        vol.Optional(CONF_UNIQUE_ID): cv.string,
        vol.Optional(CONF_DEVICE_CLASS): cv.string,
        # A single observation or a list of them; each must match one of the
        # three observation schemas above.
        vol.Required(CONF_OBSERVATIONS): vol.Schema(
            vol.All(
                cv.ensure_list,
                [vol.Any(NUMERIC_STATE_SCHEMA, STATE_SCHEMA, TEMPLATE_SCHEMA)],
            )
        ),
        vol.Required(CONF_PRIOR): vol.Coerce(float),
        vol.Optional(
            CONF_PROBABILITY_THRESHOLD, default=DEFAULT_PROBABILITY_THRESHOLD
        ): vol.Coerce(float),
    }
)


def update_probability(
    prior: float, prob_given_true: float, prob_given_false: float
) -> float:
    """Update probability using Bayes' rule.

    Computes ``P(H|E) = P(E|H)*P(H) / (P(E|H)*P(H) + P(E|not H)*P(not H))``
    where ``prior`` is P(H), ``prob_given_true`` is P(E|H) and
    ``prob_given_false`` is P(E|not H).

    If the denominator is zero, the observed evidence is impossible under both
    hypotheses given the current prior (e.g. ``prior`` 0 and
    ``prob_given_false`` 0). Bayes' rule is undefined there, so instead of
    raising ``ZeroDivisionError`` (which would abort the caller's state update)
    the prior is returned unchanged and a warning is logged.

    Returns:
        The posterior probability, or ``prior`` when the update is undefined.
    """
    numerator = prob_given_true * prior
    denominator = numerator + prob_given_false * (1 - prior)
    if denominator == 0:
        _LOGGER.warning(
            (
                "Bayesian update is undefined (prior=%s, prob_given_true=%s,"
                " prob_given_false=%s); observation ignored"
            ),
            prior,
            prob_given_true,
            prob_given_false,
        )
        return prior
    return numerator / denominator


async def async_setup_platform(
    hass: HomeAssistant,
    config: ConfigType,
    async_add_entities: AddEntitiesCallback,
    discovery_info: DiscoveryInfoType | None = None,
) -> None:
    """Set up the Bayesian Binary sensor.

    Observations lacking ``prob_given_false`` are reported (repair issue plus
    an error log) and excluded; the remaining ones are passed, in their
    original order, to a single ``BayesianBinarySensor``.
    """
    await async_setup_reload_service(hass, DOMAIN, PLATFORMS)

    name: str = config[CONF_NAME]
    unique_id: str | None = config.get(CONF_UNIQUE_ID)
    observations: list[ConfigType] = config[CONF_OBSERVATIONS]
    prior: float = config[CONF_PRIOR]
    probability_threshold: float = config[CONF_PROBABILITY_THRESHOLD]
    device_class: BinarySensorDeviceClass | None = config.get(CONF_DEVICE_CLASS)

    # Should deprecate in some future version (2022.10 at time of writing) & make prob_given_false required in schemas.
    # Partition in a single pass instead of re-filtering with `not in`.
    valid_observations: list[ConfigType] = []
    for observation in observations:
        if CONF_P_GIVEN_F in observation:
            valid_observations.append(observation)
            continue
        text: str = f"{name}/{observation.get(CONF_ENTITY_ID,'')}{observation.get(CONF_VALUE_TEMPLATE,'')}"
        raise_no_prob_given_false(hass, text)
        _LOGGER.error("Missing prob_given_false YAML entry for %s", text)

    async_add_entities(
        [
            BayesianBinarySensor(
                name,
                unique_id,
                prior,
                valid_observations,
                probability_threshold,
                device_class,
            )
        ]
    )


class BayesianBinarySensor(BinarySensorEntity):
    """Representation of a Bayesian sensor."""

    # State is pushed from state-change and template listeners, never polled.
    _attr_should_poll = False

    def __init__(
        self,
        name: str,
        unique_id: str | None,
        prior: float,
        observations: list[ConfigType],
        probability_threshold: float,
        device_class: BinarySensorDeviceClass | None,
    ) -> None:
        """Initialize the Bayesian sensor.

        Args:
            name: Entity name.
            unique_id: Optional unique id; prefixed with ``bayesian-`` when set.
            prior: Initial probability before any observation is applied.
            observations: Validated observation config dicts, each of which
                contains ``prob_given_false``.
            probability_threshold: The sensor is on when the posterior is
                greater than or equal to this value.
            device_class: Optional binary sensor device class.
        """
        self._attr_name = name
        self._attr_unique_id = unique_id and f"bayesian-{unique_id}"
        self._observations = [
            Observation(
                entity_id=observation.get(CONF_ENTITY_ID),
                platform=observation[CONF_PLATFORM],
                prob_given_false=observation[CONF_P_GIVEN_F],
                prob_given_true=observation[CONF_P_GIVEN_T],
                observed=None,
                to_state=observation.get(CONF_TO_STATE),
                above=observation.get(CONF_ABOVE),
                below=observation.get(CONF_BELOW),
                value_template=observation.get(CONF_VALUE_TEMPLATE),
            )
            for observation in observations
        ]
        self._probability_threshold = probability_threshold
        self._attr_device_class = device_class
        self._attr_is_on = False
        # Template trackers, kept so async_update can force a refresh.
        self._callbacks: list[TrackTemplateResultInfo] = []

        self.prior = prior
        self.probability = prior

        # Observations that have been evaluated at least once, keyed by id.
        self.current_observations: OrderedDict[UUID, Observation] = OrderedDict({})

        self.observations_by_entity = self._build_observations_by_entity()
        self.observations_by_template = self._build_observations_by_template()

        # Dispatch table from observation platform to its evaluator. Template
        # observations are evaluated by the template tracker instead.
        self.observation_handlers: dict[str, Callable[[Observation], bool | None]] = {
            "numeric_state": self._process_numeric_state,
            "state": self._process_state,
            "multi_state": self._process_multi_state,
        }

    async def async_added_to_hass(self) -> None:
        """Call when entity about to be added.

        All relevant update logic for instance attributes occurs within this closure.
        Other methods in this class are designed to avoid directly modifying instance
        attributes, by instead focusing on returning relevant data back to this method.

        The goal of this method is to ensure that `self.current_observations` and `self.probability`
        are set on a best-effort basis when this entity is registered with hass.

        In addition, this method must register the state listener defined within, which
        will be called any time a relevant entity changes its state.
        """

        @callback
        def async_threshold_sensor_state_listener(event: Event) -> None:
            """Handle sensor state changes.

            When a state changes, we must update our list of current observations,
            then calculate the new probability.
            """

            entity: str = event.data[CONF_ENTITY_ID]

            self.current_observations.update(self._record_entity_observations(entity))
            self.async_set_context(event.context)
            self._recalculate_and_write_state()

        self.async_on_remove(
            async_track_state_change_event(
                self.hass,
                list(self.observations_by_entity),
                async_threshold_sensor_state_listener,
            )
        )

        @callback
        def _async_template_result_changed(
            event: Event | None, updates: list[TrackTemplateResult]
        ) -> None:
            # Each tracker watches exactly one template, so the newest update is
            # the only relevant one. Read it without mutating the caller's list.
            track_template_result = updates[-1]
            template = track_template_result.template
            result = track_template_result.result
            entity: str | None = (
                None if event is None else event.data.get(CONF_ENTITY_ID)
            )
            if isinstance(result, TemplateError):
                _LOGGER.error(
                    "TemplateError('%s') while processing template '%s' in entity '%s'",
                    result,
                    template,
                    self.entity_id,
                )

                observed = None
            else:
                observed = result_as_boolean(result)

            for observation in self.observations_by_template[template]:
                observation.observed = observed

                # in some cases a template may update because of the absence of an entity
                if entity is not None:
                    observation.entity_id = entity

                self.current_observations[observation.id] = observation

            if event:
                self.async_set_context(event.context)
            self._recalculate_and_write_state()

        for template in self.observations_by_template:
            info = async_track_template_result(
                self.hass,
                [TrackTemplate(template, None)],
                _async_template_result_changed,
            )

            self._callbacks.append(info)
            self.async_on_remove(info.async_remove)
            info.async_refresh()

        self.current_observations.update(self._initialize_current_observations())
        self.probability = self._calculate_new_probability()
        self._attr_is_on = self.probability >= self._probability_threshold

        # detect mirrored entries
        for entity, observations in self.observations_by_entity.items():
            raise_mirrored_entries(
                self.hass, observations, text=f"{self._attr_name}/{entity}"
            )

        # Templates are only compared when there are exactly two of them; the
        # first observation of each template represents it.
        all_template_observations: list[Observation] = []
        for observations in self.observations_by_template.values():
            all_template_observations.append(observations[0])
        if len(all_template_observations) == 2:
            raise_mirrored_entries(
                self.hass,
                all_template_observations,
                text=f"{self._attr_name}/{all_template_observations[0].value_template}",
            )

    @callback
    def _recalculate_and_write_state(self) -> None:
        """Recompute the posterior and on/off state, then write it to hass."""
        self.probability = self._calculate_new_probability()
        self._attr_is_on = self.probability >= self._probability_threshold
        self.async_write_ha_state()

    def _initialize_current_observations(self) -> OrderedDict[UUID, Observation]:
        """Evaluate every entity-based observation once and return them by id."""
        local_observations: OrderedDict[UUID, Observation] = OrderedDict({})
        for entity in self.observations_by_entity:
            local_observations.update(self._record_entity_observations(entity))
        return local_observations

    def _record_entity_observations(
        self, entity: str
    ) -> OrderedDict[UUID, Observation]:
        """Re-evaluate all observations of ``entity`` and return them by id.

        Sets each observation's ``observed`` field from its platform handler.
        """
        local_observations: OrderedDict[UUID, Observation] = OrderedDict({})

        for observation in self.observations_by_entity[entity]:
            platform = observation.platform

            observation.observed = self.observation_handlers[platform](observation)
            local_observations[observation.id] = observation

        return local_observations

    def _calculate_new_probability(self) -> float:
        """Fold every current observation into the prior and return the posterior.

        True observations use (p_given_true, p_given_false). False observations
        use their complements. Observations whose result is None are skipped.
        """
        prior = self.prior

        for observation in self.current_observations.values():
            if observation.observed is True:
                prior = update_probability(
                    prior,
                    observation.prob_given_true,
                    observation.prob_given_false,
                )
                continue
            if observation.observed is False:
                prior = update_probability(
                    prior,
                    1 - observation.prob_given_true,
                    1 - observation.prob_given_false,
                )
                continue
            # observation.observed is None
            if observation.entity_id is not None:
                _LOGGER.debug(
                    (
                        "Observation for entity '%s' returned None, it will not be used"
                        " for Bayesian updating"
                    ),
                    observation.entity_id,
                )
                continue
            _LOGGER.debug(
                (
                    "Observation for template entity returned None rather than a valid"
                    " boolean, it will not be used for Bayesian updating"
                ),
            )
        # the prior has been updated and is now the posterior
        return prior

    def _build_observations_by_entity(self) -> dict[str, list[Observation]]:
        """Build and return data structure of the form below.

        {
            "sensor.sensor1": [Observation, Observation],
            "sensor.sensor2": [Observation],
            ...
        }

        Each "observation" must be recognized uniquely, and it should be possible
        for all relevant observations to be looked up via their `entity_id`.

        Side effect: when an entity has more than one observation, each of its
        "state" observations is switched to the "multi_state" platform.
        """

        observations_by_entity: dict[str, list[Observation]] = {}
        for observation in self._observations:
            if (key := observation.entity_id) is None:
                continue
            observations_by_entity.setdefault(key, []).append(observation)

        for entity_observations in observations_by_entity.values():
            if len(entity_observations) == 1:
                continue
            for observation in entity_observations:
                if observation.platform != "state":
                    continue
                observation.platform = "multi_state"

        return observations_by_entity

    def _build_observations_by_template(self) -> dict[Template, list[Observation]]:
        """Build and return data structure of the form below.

        {
            "template": [Observation, Observation],
            "template2": [Observation],
            ...
        }

        Each "observation" must be recognized uniquely, and it should be possible
        for all relevant observations to be looked up via their `template`.
        """

        observations_by_template: dict[Template, list[Observation]] = {}
        for observation in self._observations:
            if observation.value_template is None:
                continue

            template = observation.value_template
            observations_by_template.setdefault(template, []).append(observation)

        return observations_by_template

    def _process_numeric_state(self, entity_observation: Observation) -> bool | None:
        """Return True if numeric condition is met, return False if not, return None otherwise."""
        entity = entity_observation.entity_id

        try:
            if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
                return None
            return condition.async_numeric_state(
                self.hass,
                entity,
                entity_observation.below,
                entity_observation.above,
                None,
                entity_observation.to_dict(),
            )
        except ConditionError:
            return None

    def _process_state(self, entity_observation: Observation) -> bool | None:
        """Return True if state conditions are met, return False if they are not.

        Returns None if the state is unavailable or unknown, or if evaluating
        the condition raises ConditionError.
        """

        entity = entity_observation.entity_id

        try:
            if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
                return None

            return condition.state(self.hass, entity, entity_observation.to_state)
        except ConditionError:
            return None

    def _process_multi_state(self, entity_observation: Observation) -> bool | None:
        """Return True if state conditions are met, otherwise return None.

        Never return False as all other states should have their own probabilities configured.
        """

        entity = entity_observation.entity_id

        try:
            if condition.state(self.hass, entity, entity_observation.to_state):
                return True
        except ConditionError:
            return None
        return None

    @property
    def extra_state_attributes(self) -> dict[str, Any]:
        """Return the state attributes of the sensor."""

        return {
            ATTR_PROBABILITY: round(self.probability, 2),
            ATTR_PROBABILITY_THRESHOLD: self._probability_threshold,
            # An entity can be in more than one observation so set then list to deduplicate
            ATTR_OCCURRED_OBSERVATION_ENTITIES: list(
                {
                    observation.entity_id
                    for observation in self.current_observations.values()
                    if observation is not None
                    and observation.entity_id is not None
                    and observation.observed is not None
                }
            ),
            ATTR_OBSERVATIONS: [
                observation.to_dict()
                for observation in self.current_observations.values()
                if observation is not None
            ],
        }

    async def async_update(self) -> None:
        """Get the latest data and update the states."""
        if not self._callbacks:
            self._recalculate_and_write_state()
            return
        # Force recalc of the templates. The states will
        # update automatically.
        for call in self._callbacks:
            call.async_refresh()