Bayesian Binary Sensor (#8810)
* Bayesian Binary Sensor Why: * It would be beneficial to leverage various sensor outputs in a Bayesian manner in order to sense more complex events. This change addresses the need by: * `BayesianBinarySensor` class in `./homeassistant/components/binary_sensor/bayesian.py` * Tests in `./tests/components/binary_sensor/test_bayesian.py` Caveats: This is my first time in this code-base. I did try to follow conventions that I was able to find, but I'm sure there will be some issues to straighten out. * minor cleanup * Address reviewer's comments This change addresses the need by: * Removing `CONF_SENSOR_CLASS` and its usage in `get_deprecated`. * Make probability update function a static method, and use single `_` to match project conventions. * Address linter failures * fix `device_class` declaration * Address Comments Why: * Not validating config schema enough. * Not following common practices for async initialization. * Naive implementation of Bayes' rule. This change addresses the need by: * Improving config validation for observations. * Moving initialization logic into `async_added_to_hass`. * Re-configuring Bayesian updates to allow true P|Q usage. * address linting issues * Improve DRYness by adding `_update_current_obs` method * update doc strings and ensure functions are set up properly for async * Make only 1 state change handle * fix style * fix style part 2 * fix lintpull/9183/head
parent
0b58d5405e
commit
7de73e9ef7
|
@ -0,0 +1,211 @@
|
|||
"""
|
||||
Use Bayesian Inference to trigger a binary sensor.
|
||||
|
||||
For more details about this platform, please refer to the documentation at
|
||||
https://home-assistant.io/components/binary_sensor.bayesian/
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.components.binary_sensor import (
|
||||
BinarySensorDevice, PLATFORM_SCHEMA)
|
||||
from homeassistant.const import (
|
||||
CONF_ABOVE, CONF_BELOW, CONF_DEVICE_CLASS, CONF_ENTITY_ID, CONF_NAME,
|
||||
CONF_PLATFORM, CONF_STATE, STATE_UNKNOWN)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers import condition
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONF_OBSERVATIONS = 'observations'
|
||||
CONF_PRIOR = 'prior'
|
||||
CONF_PROBABILITY_THRESHOLD = 'probability_threshold'
|
||||
CONF_P_GIVEN_F = 'prob_given_false'
|
||||
CONF_P_GIVEN_T = 'prob_given_true'
|
||||
CONF_TO_STATE = 'to_state'
|
||||
|
||||
DEFAULT_NAME = 'BayesianBinary'
|
||||
|
||||
NUMERIC_STATE_SCHEMA = vol.Schema({
|
||||
CONF_PLATFORM: 'numeric_state',
|
||||
vol.Required(CONF_ENTITY_ID): cv.entity_id,
|
||||
vol.Optional(CONF_ABOVE): vol.Coerce(float),
|
||||
vol.Optional(CONF_BELOW): vol.Coerce(float),
|
||||
vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
|
||||
vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float)
|
||||
}, required=True)
|
||||
|
||||
STATE_SCHEMA = vol.Schema({
|
||||
CONF_PLATFORM: CONF_STATE,
|
||||
vol.Required(CONF_ENTITY_ID): cv.entity_id,
|
||||
vol.Required(CONF_TO_STATE): cv.string,
|
||||
vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
|
||||
vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float)
|
||||
}, required=True)
|
||||
|
||||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
|
||||
vol.Optional(CONF_NAME, default=DEFAULT_NAME):
|
||||
cv.string,
|
||||
vol.Optional(CONF_DEVICE_CLASS): cv.string,
|
||||
vol.Required(CONF_OBSERVATIONS): vol.Schema(
|
||||
vol.All(cv.ensure_list, [vol.Any(NUMERIC_STATE_SCHEMA,
|
||||
STATE_SCHEMA)])
|
||||
),
|
||||
vol.Required(CONF_PRIOR): vol.Coerce(float),
|
||||
vol.Optional(CONF_PROBABILITY_THRESHOLD):
|
||||
vol.Coerce(float),
|
||||
})
|
||||
|
||||
|
||||
def update_probability(prior, prob_true, prob_false):
|
||||
"""Update probability using Bayes' rule."""
|
||||
numerator = prob_true * prior
|
||||
denominator = numerator + prob_false * (1 - prior)
|
||||
|
||||
probability = numerator / denominator
|
||||
return probability
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
|
||||
"""Set up the Threshold sensor."""
|
||||
name = config.get(CONF_NAME)
|
||||
observations = config.get(CONF_OBSERVATIONS)
|
||||
prior = config.get(CONF_PRIOR)
|
||||
probability_threshold = config.get(CONF_PROBABILITY_THRESHOLD, 0.5)
|
||||
device_class = config.get(CONF_DEVICE_CLASS)
|
||||
|
||||
async_add_devices([
|
||||
BayesianBinarySensor(name, prior, observations, probability_threshold,
|
||||
device_class)
|
||||
], True)
|
||||
|
||||
|
||||
class BayesianBinarySensor(BinarySensorDevice):
|
||||
"""Representation of a Bayesian sensor."""
|
||||
|
||||
def __init__(self, name, prior, observations, probability_threshold,
|
||||
device_class):
|
||||
"""Initialize the Bayesian sensor."""
|
||||
self._name = name
|
||||
self._observations = observations
|
||||
self._probability_threshold = probability_threshold
|
||||
self._device_class = device_class
|
||||
self._deviation = False
|
||||
self.prior = prior
|
||||
self.probability = prior
|
||||
|
||||
self.current_obs = OrderedDict({})
|
||||
|
||||
self.entity_obs = {obs['entity_id']: obs for obs in self._observations}
|
||||
|
||||
self.watchers = {
|
||||
'numeric_state': self._process_numeric_state,
|
||||
'state': self._process_state
|
||||
}
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_added_to_hass(self):
|
||||
"""Call when entity about to be added to hass."""
|
||||
@callback
|
||||
# pylint: disable=invalid-name
|
||||
def async_threshold_sensor_state_listener(entity, old_state,
|
||||
new_state):
|
||||
"""Handle sensor state changes."""
|
||||
if new_state.state == STATE_UNKNOWN:
|
||||
return
|
||||
|
||||
entity_obs = self.entity_obs[entity]
|
||||
platform = entity_obs['platform']
|
||||
|
||||
self.watchers[platform](entity_obs)
|
||||
|
||||
prior = self.prior
|
||||
print(self.current_obs.values())
|
||||
for obs in self.current_obs.values():
|
||||
prior = update_probability(prior, obs['prob_true'],
|
||||
obs['prob_false'])
|
||||
|
||||
self.probability = prior
|
||||
|
||||
self.hass.async_add_job(self.async_update_ha_state, True)
|
||||
|
||||
entities = [obs['entity_id'] for obs in self._observations]
|
||||
async_track_state_change(
|
||||
self.hass, entities, async_threshold_sensor_state_listener)
|
||||
|
||||
def _update_current_obs(self, entity_observation, should_trigger):
|
||||
"""Update current observation."""
|
||||
entity = entity_observation['entity_id']
|
||||
|
||||
if should_trigger:
|
||||
prob_true = entity_observation['prob_given_true']
|
||||
prob_false = entity_observation.get(
|
||||
'prob_given_false', 1 - prob_true)
|
||||
|
||||
self.current_obs[entity] = {
|
||||
'prob_true': prob_true,
|
||||
'prob_false': prob_false
|
||||
}
|
||||
|
||||
else:
|
||||
self.current_obs.pop(entity, None)
|
||||
|
||||
def _process_numeric_state(self, entity_observation):
|
||||
"""Add entity to current_obs if numeric state conditions are met."""
|
||||
entity = entity_observation['entity_id']
|
||||
|
||||
should_trigger = condition.async_numeric_state(
|
||||
self.hass, entity,
|
||||
entity_observation.get('below'),
|
||||
entity_observation.get('above'), None, entity_observation)
|
||||
|
||||
self._update_current_obs(entity_observation, should_trigger)
|
||||
|
||||
def _process_state(self, entity_observation):
|
||||
"""Add entity to current observations if state conditions are met."""
|
||||
entity = entity_observation['entity_id']
|
||||
|
||||
should_trigger = condition.state(
|
||||
self.hass, entity, entity_observation.get('to_state'))
|
||||
|
||||
self._update_current_obs(entity_observation, should_trigger)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the sensor."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def is_on(self):
|
||||
"""Return true if sensor is on."""
|
||||
return self._deviation
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
"""No polling needed."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def device_class(self):
|
||||
"""Return the sensor class of the sensor."""
|
||||
return self._device_class
|
||||
|
||||
@property
|
||||
def device_state_attributes(self):
|
||||
"""Return the state attributes of the sensor."""
|
||||
return {
|
||||
'observations': [val for val in self.current_obs.values()],
|
||||
'probability': self.probability,
|
||||
'probability_threshold': self._probability_threshold
|
||||
}
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_update(self):
|
||||
"""Get the latest data and update the states."""
|
||||
self._deviation = bool(self.probability > self._probability_threshold)
|
|
@ -0,0 +1,176 @@
|
|||
"""The test for the bayesian sensor platform."""
|
||||
import unittest
|
||||
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.components.binary_sensor import bayesian
|
||||
|
||||
from tests.common import get_test_home_assistant
|
||||
|
||||
|
||||
class TestBayesianBinarySensor(unittest.TestCase):
|
||||
"""Test the threshold sensor."""
|
||||
|
||||
def setup_method(self, method):
|
||||
"""Set up things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
|
||||
def teardown_method(self, method):
|
||||
"""Stop everything that was started."""
|
||||
self.hass.stop()
|
||||
|
||||
def test_sensor_numeric_state(self):
|
||||
"""Test sensor on numeric state platform observations."""
|
||||
config = {
|
||||
'binary_sensor': {
|
||||
'platform':
|
||||
'bayesian',
|
||||
'name':
|
||||
'Test_Binary',
|
||||
'observations': [{
|
||||
'platform': 'numeric_state',
|
||||
'entity_id': 'sensor.test_monitored',
|
||||
'below': 10,
|
||||
'above': 5,
|
||||
'prob_given_true': 0.6
|
||||
}, {
|
||||
'platform': 'numeric_state',
|
||||
'entity_id': 'sensor.test_monitored1',
|
||||
'below': 7,
|
||||
'above': 5,
|
||||
'prob_given_true': 0.9,
|
||||
'prob_given_false': 0.1
|
||||
}],
|
||||
'prior':
|
||||
0.2,
|
||||
}
|
||||
}
|
||||
|
||||
assert setup_component(self.hass, 'binary_sensor', config)
|
||||
|
||||
self.hass.states.set('sensor.test_monitored', 4)
|
||||
self.hass.block_till_done()
|
||||
|
||||
state = self.hass.states.get('binary_sensor.test_binary')
|
||||
|
||||
self.assertEqual([], state.attributes.get('observations'))
|
||||
self.assertEqual(0.2, state.attributes.get('probability'))
|
||||
|
||||
assert state.state == 'off'
|
||||
|
||||
self.hass.states.set('sensor.test_monitored', 6)
|
||||
self.hass.block_till_done()
|
||||
self.hass.states.set('sensor.test_monitored', 4)
|
||||
self.hass.block_till_done()
|
||||
self.hass.states.set('sensor.test_monitored', 6)
|
||||
self.hass.states.set('sensor.test_monitored1', 6)
|
||||
self.hass.block_till_done()
|
||||
|
||||
state = self.hass.states.get('binary_sensor.test_binary')
|
||||
self.assertEqual([{
|
||||
'prob_false': 0.4,
|
||||
'prob_true': 0.6
|
||||
}, {
|
||||
'prob_false': 0.1,
|
||||
'prob_true': 0.9
|
||||
}], state.attributes.get('observations'))
|
||||
self.assertAlmostEqual(0.7714285714285715,
|
||||
state.attributes.get('probability'))
|
||||
|
||||
assert state.state == 'on'
|
||||
|
||||
self.hass.states.set('sensor.test_monitored', 6)
|
||||
self.hass.states.set('sensor.test_monitored1', 0)
|
||||
self.hass.block_till_done()
|
||||
self.hass.states.set('sensor.test_monitored', 4)
|
||||
self.hass.block_till_done()
|
||||
|
||||
state = self.hass.states.get('binary_sensor.test_binary')
|
||||
self.assertEqual(0.2, state.attributes.get('probability'))
|
||||
|
||||
assert state.state == 'off'
|
||||
|
||||
self.hass.states.set('sensor.test_monitored', 15)
|
||||
self.hass.block_till_done()
|
||||
|
||||
state = self.hass.states.get('binary_sensor.test_binary')
|
||||
|
||||
assert state.state == 'off'
|
||||
|
||||
def test_sensor_state(self):
|
||||
"""Test sensor on state platform observations."""
|
||||
config = {
|
||||
'binary_sensor': {
|
||||
'name':
|
||||
'Test_Binary',
|
||||
'platform':
|
||||
'bayesian',
|
||||
'observations': [{
|
||||
'platform': 'state',
|
||||
'entity_id': 'sensor.test_monitored',
|
||||
'to_state': 'off',
|
||||
'prob_given_true': 0.8,
|
||||
'prob_given_false': 0.4
|
||||
}],
|
||||
'prior':
|
||||
0.2,
|
||||
'probability_threshold':
|
||||
0.32,
|
||||
}
|
||||
}
|
||||
|
||||
assert setup_component(self.hass, 'binary_sensor', config)
|
||||
|
||||
self.hass.states.set('sensor.test_monitored', 'on')
|
||||
|
||||
state = self.hass.states.get('binary_sensor.test_binary')
|
||||
|
||||
self.assertEqual([], state.attributes.get('observations'))
|
||||
self.assertEqual(0.2, state.attributes.get('probability'))
|
||||
|
||||
assert state.state == 'off'
|
||||
|
||||
self.hass.states.set('sensor.test_monitored', 'off')
|
||||
self.hass.block_till_done()
|
||||
self.hass.states.set('sensor.test_monitored', 'on')
|
||||
self.hass.block_till_done()
|
||||
self.hass.states.set('sensor.test_monitored', 'off')
|
||||
self.hass.block_till_done()
|
||||
|
||||
state = self.hass.states.get('binary_sensor.test_binary')
|
||||
self.assertEqual([{
|
||||
'prob_true': 0.8,
|
||||
'prob_false': 0.4
|
||||
}], state.attributes.get('observations'))
|
||||
self.assertAlmostEqual(0.33333333, state.attributes.get('probability'))
|
||||
|
||||
assert state.state == 'on'
|
||||
|
||||
self.hass.states.set('sensor.test_monitored', 'off')
|
||||
self.hass.block_till_done()
|
||||
self.hass.states.set('sensor.test_monitored', 'on')
|
||||
self.hass.block_till_done()
|
||||
|
||||
state = self.hass.states.get('binary_sensor.test_binary')
|
||||
self.assertAlmostEqual(0.2, state.attributes.get('probability'))
|
||||
|
||||
assert state.state == 'off'
|
||||
|
||||
def test_probability_updates(self):
|
||||
"""Test probability update function."""
|
||||
prob_true = [0.3, 0.6, 0.8]
|
||||
prob_false = [0.7, 0.4, 0.2]
|
||||
prior = 0.5
|
||||
|
||||
for pt, pf in zip(prob_true, prob_false):
|
||||
prior = bayesian.update_probability(prior, pt, pf)
|
||||
|
||||
self.assertAlmostEqual(0.720000, prior)
|
||||
|
||||
prob_true = [0.8, 0.3, 0.9]
|
||||
prob_false = [0.6, 0.4, 0.2]
|
||||
prior = 0.7
|
||||
|
||||
for pt, pf in zip(prob_true, prob_false):
|
||||
prior = bayesian.update_probability(prior, pt, pf)
|
||||
|
||||
self.assertAlmostEqual(0.9130434782608695, prior)
|
Loading…
Reference in New Issue