"""Use Bayesian Inference to trigger a binary sensor."""
from collections import OrderedDict
from itertools import chain

import voluptuous as vol

from homeassistant.components.binary_sensor import PLATFORM_SCHEMA, BinarySensorDevice
from homeassistant.const import (
    CONF_ABOVE,
    CONF_BELOW,
    CONF_DEVICE_CLASS,
    CONF_ENTITY_ID,
    CONF_NAME,
    CONF_PLATFORM,
    CONF_STATE,
    CONF_VALUE_TEMPLATE,
    STATE_UNKNOWN,
)
from homeassistant.core import callback
from homeassistant.helpers import condition
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import async_track_state_change

# Names of the state attributes exposed by the sensor.
ATTR_OBSERVATIONS = "observations"
ATTR_OCCURRED_OBSERVATION_ENTITIES = "occurred_observation_entities"
ATTR_PROBABILITY = "probability"
ATTR_PROBABILITY_THRESHOLD = "probability_threshold"

# Configuration keys specific to this platform.
CONF_OBSERVATIONS = "observations"
CONF_PRIOR = "prior"
CONF_TEMPLATE = "template"
CONF_PROBABILITY_THRESHOLD = "probability_threshold"
CONF_P_GIVEN_F = "prob_given_false"
CONF_P_GIVEN_T = "prob_given_true"
CONF_TO_STATE = "to_state"

DEFAULT_NAME = "Bayesian Binary Sensor"
DEFAULT_PROBABILITY_THRESHOLD = 0.5

# Observation that triggers on a numeric comparison of an entity's state.
NUMERIC_STATE_SCHEMA = vol.Schema(
    {
        CONF_PLATFORM: "numeric_state",
        vol.Required(CONF_ENTITY_ID): cv.entity_id,
        vol.Optional(CONF_ABOVE): vol.Coerce(float),
        vol.Optional(CONF_BELOW): vol.Coerce(float),
        vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
        vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
    },
    required=True,
)

# Observation that triggers when an entity is in a given state.
STATE_SCHEMA = vol.Schema(
    {
        CONF_PLATFORM: CONF_STATE,
        vol.Required(CONF_ENTITY_ID): cv.entity_id,
        vol.Required(CONF_TO_STATE): cv.string,
        vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
        vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
    },
    required=True,
)

# Observation that triggers on the result of a template.
TEMPLATE_SCHEMA = vol.Schema(
    {
        CONF_PLATFORM: CONF_TEMPLATE,
        vol.Required(CONF_VALUE_TEMPLATE): cv.template,
        vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
        vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
    },
    required=True,
)

PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
    {
        vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
        vol.Optional(CONF_DEVICE_CLASS): cv.string,
        vol.Required(CONF_OBSERVATIONS): vol.Schema(
            vol.All(
                cv.ensure_list,
                [vol.Any(NUMERIC_STATE_SCHEMA, STATE_SCHEMA, TEMPLATE_SCHEMA)],
            )
        ),
        vol.Required(CONF_PRIOR): vol.Coerce(float),
        vol.Optional(
            CONF_PROBABILITY_THRESHOLD, default=DEFAULT_PROBABILITY_THRESHOLD
        ): vol.Coerce(float),
    }
)


def update_probability(prior, prob_true, prob_false):
    """Update probability using Bayes' rule.

    Args:
        prior: Probability of the hypothesis before the observation.
        prob_true: Probability of the observation if the hypothesis is true.
        prob_false: Probability of the observation if the hypothesis is false.

    Returns:
        ``prob_true * prior / (prob_true * prior + prob_false * (1 - prior))``,
        or ``prior`` unchanged when that denominator is zero.
    """
    numerator = prob_true * prior
    denominator = numerator + prob_false * (1 - prior)
    if denominator == 0:
        # The observation has zero probability under the weighted hypotheses,
        # so the ratio is undefined; keep the prior instead of raising
        # ZeroDivisionError.
        return prior
    return numerator / denominator


async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
    """Set up the Bayesian Binary sensor.

    Reads name, observations, prior, probability threshold and device class
    from ``config`` and registers a single BayesianBinarySensor, requesting an
    update before it is added.
    """
    name = config.get(CONF_NAME)
    observations = config.get(CONF_OBSERVATIONS)
    prior = config.get(CONF_PRIOR)
    probability_threshold = config.get(CONF_PROBABILITY_THRESHOLD)
    device_class = config.get(CONF_DEVICE_CLASS)

    async_add_entities(
        [
            BayesianBinarySensor(
                name, prior, observations, probability_threshold, device_class
            )
        ],
        True,
    )


class BayesianBinarySensor(BinarySensorDevice):
    """Representation of a Bayesian sensor.

    The sensor starts from a prior probability and, each time a watched
    entity changes state, re-evaluates the observations attached to that
    entity and recomputes the probability from all currently-met
    observations. It is on when the probability reaches the threshold.
    """

    def __init__(self, name, prior, observations, probability_threshold, device_class):
        """Initialize the Bayesian sensor.

        Args:
            name: Name of the sensor.
            prior: Initial probability, also used as the starting point of
                every recomputation.
            observations: List of observation config dicts. Each dict is
                given an ``"id"`` key holding its index in this list.
            probability_threshold: Probability at or above which the sensor
                is on.
            device_class: Device class reported by the sensor.
        """
        self._name = name
        self._observations = observations
        self._probability_threshold = probability_threshold
        self._device_class = device_class
        self._deviation = False

        self.prior = prior
        self.probability = prior

        # Observations whose condition currently holds, keyed by observation id.
        self.current_obs = OrderedDict()
        # Entity ids involved in each observation, indexed by observation id.
        self.entity_obs_dict = []
        # Maps each watched entity id to the observations that depend on it.
        self.entity_obs = {}

        for ind, obs in enumerate(self._observations):
            obs["id"] = ind

            if CONF_ENTITY_ID in obs:
                entity_id = obs[CONF_ENTITY_ID]
                self.entity_obs_dict.append([entity_id])
                self.entity_obs.setdefault(entity_id, []).append(obs)

            if CONF_VALUE_TEMPLATE in obs:
                # Extract once and reuse for both lookup structures.
                entities = list(obs[CONF_VALUE_TEMPLATE].extract_entities())
                self.entity_obs_dict.append(entities)
                for ent in entities:
                    self.entity_obs.setdefault(ent, []).append(obs)

        # Dispatch table from observation platform to its evaluator.
        self.watchers = {
            "numeric_state": self._process_numeric_state,
            "state": self._process_state,
            "template": self._process_template,
        }

    async def async_added_to_hass(self):
        """Call when entity about to be added.

        Registers a state-change listener for every entity referenced by an
        observation.
        """

        @callback
        def async_threshold_sensor_state_listener(entity, old_state, new_state):
            """Handle sensor state changes."""
            # new_state is None when the tracked entity was removed; there is
            # nothing to evaluate then, and reading .state would raise.
            if new_state is None or new_state.state == STATE_UNKNOWN:
                return

            # Re-evaluate only the observations that depend on this entity.
            for entity_obs in self.entity_obs[entity]:
                platform = entity_obs[CONF_PLATFORM]
                self.watchers[platform](entity_obs)

            # Recompute from the configured prior, folding in every
            # observation that is currently met.
            prior = self.prior
            for obs in self.current_obs.values():
                prior = update_probability(prior, obs["prob_true"], obs["prob_false"])
            self.probability = prior

            self.hass.async_add_job(self.async_update_ha_state, True)

        async_track_state_change(
            self.hass, list(self.entity_obs), async_threshold_sensor_state_listener
        )

    def _update_current_obs(self, entity_observation, should_trigger):
        """Update current observation.

        Stores the observation's probabilities in ``current_obs`` when
        ``should_trigger`` is truthy, otherwise removes it if present.
        ``prob_given_false`` defaults to ``1 - prob_given_true``.
        """
        obs_id = entity_observation["id"]

        if should_trigger:
            prob_true = entity_observation[CONF_P_GIVEN_T]
            prob_false = entity_observation.get(CONF_P_GIVEN_F, 1 - prob_true)
            self.current_obs[obs_id] = {
                "prob_true": prob_true,
                "prob_false": prob_false,
            }
        else:
            self.current_obs.pop(obs_id, None)

    def _process_numeric_state(self, entity_observation):
        """Add entity to current_obs if numeric state conditions are met."""
        entity = entity_observation[CONF_ENTITY_ID]

        should_trigger = condition.async_numeric_state(
            self.hass,
            entity,
            entity_observation.get(CONF_BELOW),
            entity_observation.get(CONF_ABOVE),
            None,
            entity_observation,
        )

        self._update_current_obs(entity_observation, should_trigger)

    def _process_state(self, entity_observation):
        """Add entity to current observations if state conditions are met."""
        entity = entity_observation[CONF_ENTITY_ID]

        should_trigger = condition.state(
            self.hass, entity, entity_observation.get(CONF_TO_STATE)
        )

        self._update_current_obs(entity_observation, should_trigger)

    def _process_template(self, entity_observation):
        """Add entity to current_obs if template is true."""
        template = entity_observation.get(CONF_VALUE_TEMPLATE)
        template.hass = self.hass

        should_trigger = condition.async_template(
            self.hass, template, entity_observation
        )

        self._update_current_obs(entity_observation, should_trigger)

    @property
    def name(self):
        """Return the name of the sensor."""
        return self._name

    @property
    def is_on(self):
        """Return true if sensor is on."""
        return self._deviation

    @property
    def should_poll(self):
        """No polling needed."""
        return False

    @property
    def device_class(self):
        """Return the sensor class of the sensor."""
        return self._device_class

    @property
    def device_state_attributes(self):
        """Return the state attributes of the sensor.

        Includes the probabilities of the currently-met observations, the
        de-duplicated entity ids behind them, the probability rounded to two
        decimals, and the configured threshold.
        """
        occurred_entities = set(
            chain.from_iterable(
                self.entity_obs_dict[obs_id] for obs_id in self.current_obs
            )
        )
        return {
            ATTR_OBSERVATIONS: list(self.current_obs.values()),
            ATTR_OCCURRED_OBSERVATION_ENTITIES: list(occurred_entities),
            ATTR_PROBABILITY: round(self.probability, 2),
            ATTR_PROBABILITY_THRESHOLD: self._probability_threshold,
        }

    async def async_update(self):
        """Get the latest data and update the states."""
        self._deviation = bool(self.probability >= self._probability_threshold)