Events and States are no longer dicts but objects.

pull/2/head
Paulus Schoutsen 2014-01-19 19:10:40 -08:00
parent ae2058de70
commit 3c3e7e5825
9 changed files with 212 additions and 129 deletions

View File

@ -2,19 +2,21 @@
homeassistant
~~~~~~~~~~~~~
Module to control the lights based on devices at home and the state of the sun.
Home Assistant is a Home Automation framework for observing the state
of objects and react to changes.
"""
import time
import logging
import threading
from collections import defaultdict, namedtuple
from datetime import datetime
import datetime as dt
import homeassistant.util as util
logging.basicConfig(level=logging.INFO)
ALL_EVENTS = '*'
MATCH_ALL = '*'
DOMAIN = "homeassistant"
@ -38,8 +40,6 @@ TIMER_INTERVAL = 10 # seconds
# every minute.
assert 60 % TIMER_INTERVAL == 0, "60 % TIMER_INTERVAL should be 0!"
DATE_STR_FORMAT = "%H:%M:%S %d-%m-%Y"
def start_home_assistant(bus):
""" Start home assistant. """
@ -60,37 +60,22 @@ def start_home_assistant(bus):
break
def datetime_to_str(dattim):
""" Converts datetime to a string format.
@rtype : str
"""
return dattim.strftime(DATE_STR_FORMAT)
def str_to_datetime(dt_str):
""" Converts a string to a datetime object.
@rtype: datetime
"""
return datetime.strptime(dt_str, DATE_STR_FORMAT)
def _ensure_list(parameter):
""" Wraps parameter in a list if it is not one and returns it.
@rtype : list
"""
return parameter if isinstance(parameter, list) else [parameter]
def _process_match_param(parameter):
""" Wraps parameter in a list if it is not one and returns it. """
if parameter is None:
return MATCH_ALL
elif isinstance(parameter, list):
return parameter
else:
return [parameter]
def _matcher(subject, pattern):
""" Returns True if subject matches the pattern.
Pattern is either a list of allowed subjects or a '*'.
@rtype : bool
Pattern is either a list of allowed subjects or a `MATCH_ALL`.
"""
return '*' in pattern or subject in pattern
return MATCH_ALL == pattern or subject in pattern
def split_state_category(category):
@ -98,36 +83,26 @@ def split_state_category(category):
return category.split(".", 1)
def filter_categories(categories, domain_filter=None, object_id_only=False):
""" Filter a list of categories based on domain. Setting object_id_only
def filter_categories(categories, domain_filter=None, strip_domain=False):
""" Filter a list of categories based on domain. Setting strip_domain
will only return the object_ids. """
return [
split_state_category(cat)[1] if object_id_only else cat
split_state_category(cat)[1] if strip_domain else cat
for cat in categories if
not domain_filter or cat.startswith(domain_filter)
]
def create_state(state, attributes=None, last_changed=None):
""" Creates a new state and initializes defaults where necessary. """
attributes = attributes or {}
last_changed = last_changed or datetime.now()
return {'state': state,
'attributes': attributes,
'last_changed': datetime_to_str(last_changed)}
def track_state_change(bus, category, action, from_state=None, to_state=None):
""" Helper method to track specific state changes. """
from_state = _ensure_list(from_state) if from_state else [ALL_EVENTS]
to_state = _ensure_list(to_state) if to_state else [ALL_EVENTS]
from_state = _process_match_param(from_state)
to_state = _process_match_param(to_state)
def listener(event):
""" State change listener that listens for specific state changes. """
if category == event.data['category'] and \
_matcher(event.data['old_state']['state'], from_state) and \
_matcher(event.data['new_state']['state'], to_state):
_matcher(event.data['old_state'].state, from_state) and \
_matcher(event.data['new_state'].state, to_state):
action(event.data['category'],
event.data['old_state'],
@ -138,19 +113,19 @@ def track_state_change(bus, category, action, from_state=None, to_state=None):
# pylint: disable=too-many-arguments
def track_time_change(bus, action,
year='*', month='*', day='*',
hour='*', minute='*', second='*',
year=None, month=None, day=None,
hour=None, minute=None, second=None,
point_in_time=None, listen_once=False):
""" Adds a listener that will listen for a specified or matching time. """
year, month = _ensure_list(year), _ensure_list(month)
day = _ensure_list(day)
year, month = _process_match_param(year), _process_match_param(month)
day = _process_match_param(day)
hour, minute = _ensure_list(hour), _ensure_list(minute)
second = _ensure_list(second)
hour, minute = _process_match_param(hour), _process_match_param(minute)
second = _process_match_param(second)
def listener(event):
""" Listens for matching time_changed events. """
now = str_to_datetime(event.data['now'])
now = event.data['now']
if (point_in_time and now > point_in_time) or \
(not point_in_time and
@ -180,7 +155,7 @@ class Bus(object):
"""
def __init__(self):
self._event_listeners = defaultdict(list)
self._event_listeners = {}
self._services = {}
self.logger = logging.getLogger(__name__)
@ -196,8 +171,7 @@ class Bus(object):
of listeners.
"""
return {key: len(self._event_listeners[key])
for key in self._event_listeners.keys()
if len(self._event_listeners[key]) > 0}
for key in self._event_listeners}
def call_service(self, domain, service, service_data=None):
""" Calls a service. """
@ -236,8 +210,16 @@ class Bus(object):
def fire_event(self, event_type, event_data=None):
""" Fire an event. """
if not event_data:
event_data = {}
# Copy the list of the current listeners because some listeners
# choose to remove themselves as a listener while being executed
# which causes the iterator to be confused.
listeners = self._event_listeners.get(MATCH_ALL, []) + \
self._event_listeners.get(event_type, [])
if not listeners:
return
event_data = event_data or {}
self.logger.info("Bus:Event {}: {}".format(
event_type, event_data))
@ -246,10 +228,7 @@ class Bus(object):
""" Fire listeners for event. """
event = Event(self, event_type, event_data)
# We do not use itertools.chain() because some listeners might
# choose to remove themselves as a listener while being executed
for listener in self._event_listeners[ALL_EVENTS] + \
self._event_listeners[event.event_type]:
for listener in listeners:
try:
listener(event)
@ -262,15 +241,19 @@ class Bus(object):
def listen_event(self, event_type, listener):
""" Listen for all events or events of a specific type.
To listen to all events specify the constant ``ALL_EVENTS``
To listen to all events specify the constant ``MATCH_ALL``
as event_type.
"""
self._event_listeners[event_type].append(listener)
try:
self._event_listeners[event_type].append(listener)
except KeyError: # event_type did not exist
self._event_listeners[event_type] = [listener]
def listen_once_event(self, event_type, listener):
""" Listen once for event of a specific type.
To listen to all events specify the constant ``ALL_EVENTS``
To listen to all events specify the constant ``MATCH_ALL``
as event_type.
Note: at the moment it is impossible to remove a one time listener.
@ -292,10 +275,67 @@ class Bus(object):
if len(self._event_listeners[event_type]) == 0:
del self._event_listeners[event_type]
except ValueError:
except (KeyError, ValueError):
pass
class State(object):
""" Object to represent a state within the state machine. """
def __init__(self, state, attributes=None, last_changed=None):
self.state = state
self.attributes = attributes or {}
last_changed = last_changed or dt.datetime.now()
# Strip microsecond from last_changed else we cannot guarantee
# state == State.from_json_dict(state.to_json_dict())
# This behavior occurs because to_json_dict strips microseconds
if last_changed.microsecond:
self.last_changed = last_changed - dt.timedelta(
microseconds=last_changed.microsecond)
else:
self.last_changed = last_changed
def to_json_dict(self, category=None):
""" Converts State to a dict to be used within JSON.
Ensures: state == State.from_json_dict(state.to_json_dict()) """
json_dict = {'state': self.state,
'attributes': self.attributes,
'last_changed': util.datetime_to_str(self.last_changed)}
if category:
json_dict['category'] = category
return json_dict
def copy(self):
""" Creates a copy of itself. """
return State(self.state, dict(self.attributes), self.last_changed)
@staticmethod
def from_json_dict(json_dict):
""" Static method to create a state from a dict.
Ensures: state == State.from_json_dict(state.to_json_dict()) """
try:
last_changed = json_dict.get('last_changed')
if last_changed:
last_changed = util.str_to_datetime(last_changed)
return State(json_dict['state'],
json_dict.get('attributes'),
last_changed)
except KeyError: # if key 'state' did not exist
return None
def __repr__(self):
return "{}({}, {})".format(
self.state, self.attributes,
util.datetime_to_str(self.last_changed))
class StateMachine(object):
""" Helper class that tracks the state of different categories. """
@ -333,16 +373,16 @@ class StateMachine(object):
# Add category if it does not exist
if category not in self.states:
self.states[category] = create_state(new_state, attributes)
self.states[category] = State(new_state, attributes)
# Change state and fire listeners
else:
old_state = self.states[category]
if old_state['state'] != new_state or \
old_state['attributes'] != attributes:
if old_state.state != new_state or \
old_state.attributes != attributes:
self.states[category] = create_state(new_state, attributes)
self.states[category] = State(new_state, attributes)
self.bus.fire_event(EVENT_STATE_CHANGED,
{'category': category,
@ -356,7 +396,7 @@ class StateMachine(object):
the state of the specified category. """
try:
# Make a copy so people won't mutate the state
return dict(self.states[category])
return self.states[category].copy()
except KeyError:
# If category does not exist
@ -366,7 +406,7 @@ class StateMachine(object):
""" Returns True if category exists and is specified state. """
cur_state = self.get_state(category)
return cur_state and cur_state['state'] == state
return cur_state and cur_state.state == state
class Timer(threading.Thread):
@ -389,7 +429,7 @@ class Timer(threading.Thread):
last_fired_on_second = -1
while True:
now = datetime.now()
now = dt.datetime.now()
# First check checks if we are not on a second matching the
# timer interval. Second check checks if we did not already fire
@ -407,12 +447,12 @@ class Timer(threading.Thread):
time.sleep(slp_seconds)
now = datetime.now()
now = dt.datetime.now()
last_fired_on_second = now.second
self.bus.fire_event(EVENT_TIME_CHANGED,
{'now': datetime_to_str(now)})
{'now': now})
class HomeAssistantException(Exception):

View File

@ -36,10 +36,10 @@ def turn_off(statemachine, cc_id=None):
state = statemachine.get_state(cat)
if state and \
state['state'] != STATE_NO_APP or \
state['state'] != pychromecast.APP_ID_HOME:
state.state != STATE_NO_APP or \
state.state != pychromecast.APP_ID_HOME:
pychromecast.quit_app(state['attributes'][ATTR_HOST])
pychromecast.quit_app(state.attributes[ATTR_HOST])
def setup(bus, statemachine, host):

View File

@ -92,7 +92,7 @@ def setup(bus, statemachine, light_group=None):
# Specific device came home ?
if (category != device_tracker.STATE_CATEGORY_ALL_DEVICES and
new_state['state'] == ha.STATE_HOME):
new_state.state == ha.STATE_HOME):
# These variables are needed for the elif check
now = datetime.now()
@ -128,7 +128,7 @@ def setup(bus, statemachine, light_group=None):
# Did all devices leave the house?
elif (category == device_tracker.STATE_CATEGORY_ALL_DEVICES and
new_state['state'] == ha.STATE_NOT_HOME and lights_are_on):
new_state.state == ha.STATE_NOT_HOME and lights_are_on):
logger.info(
"Everyone has left but there are devices on. Turning them off")

View File

@ -35,12 +35,11 @@ def is_on(statemachine, group):
state = statemachine.get_state(group)
if state:
group_type = _get_group_type(state['state'])
group_type = _get_group_type(state.state)
if group_type:
group_on = _GROUP_TYPES[group_type][0]
return state['state'] == group_on
# We found group_type, compare to ON-state
return state.state == _GROUP_TYPES[group_type][0]
else:
return False
else:
@ -51,7 +50,7 @@ def get_categories(statemachine, group):
""" Get the categories that make up this group. """
state = statemachine.get_state(group)
return state['attributes'][STATE_ATTR_CATEGORIES] if state else []
return state.attributes[STATE_ATTR_CATEGORIES] if state else []
# pylint: disable=too-many-branches
@ -73,7 +72,7 @@ def setup(bus, statemachine, name, categories):
# Try to determine group type if we didn't yet
if not group_type and state:
group_type = _get_group_type(state['state'])
group_type = _get_group_type(state.state)
if group_type:
group_on, group_off = _GROUP_TYPES[group_type]
@ -82,7 +81,7 @@ def setup(bus, statemachine, name, categories):
else:
# We did not find a matching group_type
errors.append("Found unexpected state '{}'".format(
name, state['state']))
name, state.state))
break
@ -91,13 +90,13 @@ def setup(bus, statemachine, name, categories):
errors.append("Category {} does not exist".format(cat))
# Check if category is valid state
elif state['state'] != group_off and state['state'] != group_on:
elif state.state != group_off and state.state != group_on:
errors.append("State of {} is {} (expected: {}, {})".format(
cat, state['state'], group_off, group_on))
cat, state.state, group_off, group_on))
# Keep track of the group state to init later on
elif group_state == group_off and state['state'] == group_on:
elif group_state == group_off and state.state == group_on:
group_state = group_on
if errors:
@ -114,17 +113,17 @@ def setup(bus, statemachine, name, categories):
""" Updates the group state based on a state change by a tracked
category. """
cur_group_state = statemachine.get_state(group_cat)['state']
cur_group_state = statemachine.get_state(group_cat).state
# if cur_group_state = OFF and new_state = ON: set ON
# if cur_group_state = ON and new_state = OFF: research
# else: ignore
if cur_group_state == group_off and new_state['state'] == group_on:
if cur_group_state == group_off and new_state.state == group_on:
statemachine.set_state(group_cat, group_on, state_attr)
elif cur_group_state == group_on and new_state['state'] == group_off:
elif cur_group_state == group_on and new_state.state == group_off:
# Check if any of the other states is still on
if not any([statemachine.is_state(cat, group_on)

View File

@ -341,16 +341,16 @@ class RequestHandler(BaseHTTPRequestHandler):
state = self.server.statemachine.get_state(category)
attributes = "<br>".join(
["{}: {}".format(attr, state['attributes'][attr])
for attr in state['attributes']])
["{}: {}".format(attr, state.attributes[attr])
for attr in state.attributes])
write(("<tr>"
"<td>{}</td><td>{}</td><td>{}</td><td>{}</td>"
"</tr>").format(
category,
state['state'],
state.state,
attributes,
state['last_changed']))
state.last_changed))
# Change state form
write(("<tr><td><input name='category' class='form-control' "
@ -518,9 +518,8 @@ class RequestHandler(BaseHTTPRequestHandler):
if self.use_json:
state = self.server.statemachine.get_state(category)
state['category'] = category
self._write_json(state, status_code=HTTP_CREATED,
self._write_json(state.to_json_dict(category),
status_code=HTTP_CREATED,
location=
URL_API_STATES_CATEGORY.format(category))
else:
@ -619,10 +618,7 @@ class RequestHandler(BaseHTTPRequestHandler):
state = self.server.statemachine.get_state(category)
if state:
state['category'] = category
self._write_json(state)
self._write_json(state.to_json_dict(category))
else:
# If category does not exist
self._message("State does not exist.", HTTP_UNPROCESSABLE_ENTITY)

View File

@ -8,6 +8,7 @@ import logging
from datetime import timedelta
import homeassistant as ha
import homeassistant.util as util
STATE_CATEGORY = "weather.sun"
@ -27,16 +28,16 @@ def next_setting(statemachine):
""" Returns the datetime object representing the next sun setting. """
state = statemachine.get_state(STATE_CATEGORY)
return None if not state else ha.str_to_datetime(
state['attributes'][STATE_ATTR_NEXT_SETTING])
return None if not state else util.str_to_datetime(
state.attributes[STATE_ATTR_NEXT_SETTING])
def next_rising(statemachine):
""" Returns the datetime object representing the next sun setting. """
state = statemachine.get_state(STATE_CATEGORY)
return None if not state else ha.str_to_datetime(
state['attributes'][STATE_ATTR_NEXT_RISING])
return None if not state else util.str_to_datetime(
state.attributes[STATE_ATTR_NEXT_RISING])
def setup(bus, statemachine, latitude, longitude):
@ -74,8 +75,8 @@ def setup(bus, statemachine, latitude, longitude):
next_change.strftime("%H:%M")))
state_attributes = {
STATE_ATTR_NEXT_RISING: ha.datetime_to_str(next_rising_dt),
STATE_ATTR_NEXT_SETTING: ha.datetime_to_str(next_setting_dt)
STATE_ATTR_NEXT_RISING: util.datetime_to_str(next_rising_dt),
STATE_ATTR_NEXT_SETTING: util.datetime_to_str(next_setting_dt)
}
statemachine.set_state(STATE_CATEGORY, new_state, state_attributes)

View File

@ -49,6 +49,18 @@ def _setup_call_api(host, port, api_password):
return _call_api
class JSONEncoder(json.JSONEncoder):
""" JSONEncoder that supports Home Assistant objects. """
def default(self, obj): # pylint: disable=method-hidden
""" Checks if Home Assistat object and encodes if possible.
Else hand it off to original method. """
if isinstance(obj, ha.State):
return obj.to_json_dict()
return json.JSONEncoder.default(self, obj)
class Bus(ha.Bus):
""" Drop-in replacement for a normal bus that will forward interaction to
a remote bus.
@ -140,7 +152,10 @@ class Bus(ha.Bus):
def fire_event(self, event_type, event_data=None):
""" Fire an event. """
data = {'event_data': json.dumps(event_data)} if event_data else None
if event_data:
data = {'event_data': json.dumps(event_data, cls=JSONEncoder)}
else:
data = None
req = self._call_api(METHOD_POST,
hah.URL_API_EVENTS_EVENT.format(event_type),
@ -159,6 +174,12 @@ class Bus(ha.Bus):
Will throw NotImplementedError. """
raise NotImplementedError
def listen_once_event(self, event_type, listener):
""" Not implemented for remote bus.
Will throw NotImplementedError. """
raise NotImplementedError
def remove_event_listener(self, event_type, listener):
""" Not implemented for remote bus.
@ -201,6 +222,13 @@ class StateMachine(ha.StateMachine):
self.logger.exception("StateMachine:Got unexpected result (2)")
return []
def remove_category(self, category):
""" This method is not implemented for remote statemachine.
Throws NotImplementedError. """
raise NotImplementedError
def set_state(self, category, new_state, attributes=None):
""" Set the state of a category, add category if it does not exist.
@ -243,9 +271,7 @@ class StateMachine(ha.StateMachine):
if req.status_code == 200:
data = req.json()
return ha.create_state(data['state'], data['attributes'],
ha.str_to_datetime(
data['last_changed']))
return ha.State.from_json_dict(data)
elif req.status_code == 422:
# Category does not exist

View File

@ -96,7 +96,7 @@ class TestHTTPInterface(unittest.TestCase):
"new_state": "debug_state_change2",
"api_password": API_PASSWORD})
self.assertEqual(self.statemachine.get_state("test.test")['state'],
self.assertEqual(self.statemachine.get_state("test.test").state,
"debug_state_change2")
def test_debug_fire_event(self):
@ -138,14 +138,13 @@ class TestHTTPInterface(unittest.TestCase):
_url(hah.URL_API_STATES_CATEGORY.format("test")),
data={"api_password": API_PASSWORD})
data = req.json()
data = ha.State.from_json_dict(req.json())
state = self.statemachine.get_state("test")
self.assertEqual(data['category'], "test")
self.assertEqual(data['state'], state['state'])
self.assertEqual(data['last_changed'], state['last_changed'])
self.assertEqual(data['attributes'], state['attributes'])
self.assertEqual(data.state, state.state)
self.assertEqual(data.last_changed, state.last_changed)
self.assertEqual(data.attributes, state.attributes)
def test_api_get_non_existing_state(self):
""" Test if the debug interface allows us to get a state. """
@ -164,7 +163,7 @@ class TestHTTPInterface(unittest.TestCase):
data={"new_state": "debug_state_change2",
"api_password": API_PASSWORD})
self.assertEqual(self.statemachine.get_state("test.test")['state'],
self.assertEqual(self.statemachine.get_state("test.test").state,
"debug_state_change2")
# pylint: disable=invalid-name
@ -181,7 +180,7 @@ class TestHTTPInterface(unittest.TestCase):
"api_password": API_PASSWORD})
cur_state = (self.statemachine.
get_state("test_category_that_does_not_exist")['state'])
get_state("test_category_that_does_not_exist").state)
self.assertEqual(req.status_code, 201)
self.assertEqual(cur_state, new_state)
@ -339,9 +338,9 @@ class TestRemote(unittest.TestCase):
state = self.statemachine.get_state("test")
self.assertEqual(remote_state['state'], state['state'])
self.assertEqual(remote_state['last_changed'], state['last_changed'])
self.assertEqual(remote_state['attributes'], state['attributes'])
self.assertEqual(remote_state.state, state.state)
self.assertEqual(remote_state.last_changed, state.last_changed)
self.assertEqual(remote_state.attributes, state.attributes)
def test_remote_sm_get_non_existing_state(self):
""" Test if the debug interface allows us to list state categories. """
@ -354,8 +353,8 @@ class TestRemote(unittest.TestCase):
state = self.statemachine.get_state("test")
self.assertEqual(state['state'], "set_remotely")
self.assertEqual(state['attributes']['test'], 1)
self.assertEqual(state.state, "set_remotely")
self.assertEqual(state.attributes['test'], 1)
def test_remote_eb_listening_for_same(self):
""" Test if remote EB correctly reports listener overview. """

View File

@ -1,10 +1,13 @@
""" Helper methods for various modules. """
import datetime
import re
RE_SANITIZE_FILENAME = re.compile(r"(~|(\.\.)|/|\+)")
RE_SLUGIFY = re.compile(r'[^A-Za-z0-9_]+')
DATE_STR_FORMAT = "%H:%M:%S %d-%m-%Y"
def sanitize_filename(filename):
""" Sanitizes a filename by removing .. / and \\. """
@ -16,3 +19,22 @@ def slugify(text):
text = text.strip().replace(" ", "_")
return RE_SLUGIFY.sub("", text)
def datetime_to_str(dattim):
""" Converts datetime to a string format.
@rtype : str
"""
return dattim.strftime(DATE_STR_FORMAT)
def str_to_datetime(dt_str):
""" Converts a string to a datetime object.
@rtype: datetime
"""
try:
return datetime.datetime.strptime(dt_str, DATE_STR_FORMAT)
except ValueError: # If dt_str did not match our format
return None