Allow easier customization of whole domain, entity lists, globs. (#5215)

pull/5511/head
andrey-git 2017-01-22 21:19:50 +02:00 committed by Johann Kellerman
parent ab19577322
commit addc2c4340
8 changed files with 285 additions and 48 deletions

View File

@ -6,7 +6,7 @@ import os
import shutil
from types import MappingProxyType
# pylint: disable=unused-import
from typing import Any, Tuple # NOQA
from typing import Any, List, Tuple # NOQA
import voluptuous as vol
@ -14,15 +14,15 @@ from homeassistant.const import (
CONF_LATITUDE, CONF_LONGITUDE, CONF_NAME, CONF_PACKAGES, CONF_UNIT_SYSTEM,
CONF_TIME_ZONE, CONF_CUSTOMIZE, CONF_ELEVATION, CONF_UNIT_SYSTEM_METRIC,
CONF_UNIT_SYSTEM_IMPERIAL, CONF_TEMPERATURE_UNIT, TEMP_CELSIUS,
__version__)
from homeassistant.core import valid_entity_id, DOMAIN as CONF_CORE
CONF_ENTITY_ID, __version__)
from homeassistant.core import DOMAIN as CONF_CORE
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import get_component
from homeassistant.util.yaml import load_yaml
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import set_customize
from homeassistant.util import dt as date_util, location as loc_util
from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM
from homeassistant.helpers.customize import set_customize
_LOGGER = logging.getLogger(__name__)
@ -87,19 +87,24 @@ tts:
"""
def _valid_customize(value):
"""Config validator for customize."""
if not isinstance(value, dict):
raise vol.Invalid('Expected dictionary')
CUSTOMIZE_SCHEMA_ENTRY = vol.Schema({
vol.Required(CONF_ENTITY_ID): vol.All(
cv.ensure_list_csv, vol.Length(min=1), [cv.string], [vol.Lower])
}, extra=vol.ALLOW_EXTRA)
for key, val in value.items():
if not valid_entity_id(key):
raise vol.Invalid('Invalid entity ID: {}'.format(key))
if not isinstance(val, dict):
raise vol.Invalid('Value of {} is not a dictionary'.format(key))
def _convert_old_config(inp: Any) -> List:
if not isinstance(inp, dict):
return cv.ensure_list(inp)
if CONF_ENTITY_ID in inp:
return [inp] # sigle entry
res = []
return value
inp = vol.Schema({cv.match_all: dict})(inp)
for key, val in inp.items():
val[CONF_ENTITY_ID] = key
res.append(val)
return res
PACKAGES_CONFIG_SCHEMA = vol.Schema({
@ -116,7 +121,8 @@ CORE_CONFIG_SCHEMA = vol.Schema({
CONF_UNIT_SYSTEM: cv.unit_system,
CONF_TIME_ZONE: cv.time_zone,
vol.Required(CONF_CUSTOMIZE,
default=MappingProxyType({})): _valid_customize,
default=MappingProxyType({})): vol.All(
_convert_old_config, [CUSTOMIZE_SCHEMA_ENTRY]),
vol.Optional(CONF_PACKAGES, default={}): PACKAGES_CONFIG_SCHEMA,
})
@ -301,7 +307,7 @@ def async_process_ha_core_config(hass, config):
if CONF_TIME_ZONE in config:
set_time_zone(config.get(CONF_TIME_ZONE))
set_customize(config.get(CONF_CUSTOMIZE) or {})
set_customize(hass, config.get(CONF_CUSTOMIZE) or {})
if CONF_UNIT_SYSTEM in config:
if config[CONF_UNIT_SYSTEM] == CONF_UNIT_SYSTEM_IMPERIAL:

View File

@ -376,6 +376,8 @@ def ordered_dict(value_validator, key_validator=match_all):
"""Validate ordered dict."""
config = OrderedDict()
if not isinstance(value, dict):
raise vol.Invalid('Value {} is not a dictionary'.format(value))
for key, val in value.items():
v_res = item_validator({key: val})
config.update(v_res)
@ -385,6 +387,13 @@ def ordered_dict(value_validator, key_validator=match_all):
return validator
def ensure_list_csv(value: Any) -> Sequence:
"""Ensure that input is a list or make one from comma-separated string."""
if isinstance(value, str):
return [member.strip() for member in value.split(',')]
return ensure_list(value)
# Validator helpers
def key_dependency(key, dependency):

View File

@ -0,0 +1,80 @@
"""A helper module for customization."""
import collections
from typing import Dict, List
import fnmatch
from homeassistant.const import CONF_ENTITY_ID
from homeassistant.core import HomeAssistant, split_entity_id
_OVERWRITE_KEY = 'overwrite'
_OVERWRITE_CACHE_KEY = 'overwrite_cache'
def set_customize(hass: HomeAssistant, customize: List[Dict]) -> None:
"""Overwrite all current customize settings.
Async friendly.
"""
hass.data[_OVERWRITE_KEY] = customize
hass.data[_OVERWRITE_CACHE_KEY] = {}
def get_overrides(hass: HomeAssistant, entity_id: str) -> Dict:
"""Return a dictionary of overrides related to entity_id.
Whole-domain overrides are of lowest priorities,
then glob on entity ID, and finally exact entity_id
matches are of highest priority.
The lookups are cached.
"""
if _OVERWRITE_CACHE_KEY in hass.data and \
entity_id in hass.data[_OVERWRITE_CACHE_KEY]:
return hass.data[_OVERWRITE_CACHE_KEY][entity_id]
if _OVERWRITE_KEY not in hass.data:
return {}
domain_result = {} # type: Dict[str, Any]
glob_result = {} # type: Dict[str, Any]
exact_result = {} # type: Dict[str, Any]
domain = split_entity_id(entity_id)[0]
def clean_entry(entry: Dict) -> Dict:
"""Clean up entity-matching keys."""
entry.pop(CONF_ENTITY_ID, None)
return entry
def deep_update(target: Dict, source: Dict) -> None:
"""Deep update a dictionary."""
for key, value in source.items():
if isinstance(value, collections.Mapping):
updated_value = target.get(key, {})
# If the new value is map, but the old value is not -
# overwrite the old value.
if not isinstance(updated_value, collections.Mapping):
updated_value = {}
deep_update(updated_value, value)
target[key] = updated_value
else:
target[key] = source[key]
for rule in hass.data[_OVERWRITE_KEY]:
if CONF_ENTITY_ID in rule:
entities = rule[CONF_ENTITY_ID]
if domain in entities:
deep_update(domain_result, rule)
if entity_id in entities:
deep_update(exact_result, rule)
for entity_id_glob in entities:
if entity_id_glob == entity_id:
continue
if fnmatch.fnmatchcase(entity_id, entity_id_glob):
deep_update(glob_result, rule)
break
result = {}
deep_update(result, clean_entry(domain_result))
deep_update(result, clean_entry(glob_result))
deep_update(result, clean_entry(exact_result))
if _OVERWRITE_CACHE_KEY not in hass.data:
hass.data[_OVERWRITE_CACHE_KEY] = {}
hass.data[_OVERWRITE_CACHE_KEY][entity_id] = result
return result

View File

@ -4,7 +4,7 @@ import logging
import functools as ft
from timeit import default_timer as timer
from typing import Any, Optional, List, Dict
from typing import Optional, List
from homeassistant.const import (
ATTR_ASSUMED_STATE, ATTR_FRIENDLY_NAME, ATTR_HIDDEN, ATTR_ICON,
@ -16,9 +16,7 @@ from homeassistant.exceptions import NoEntitySpecifiedError
from homeassistant.util import ensure_unique_string, slugify
from homeassistant.util.async import (
run_coroutine_threadsafe, run_callback_threadsafe)
# Entity attributes that we will overwrite
_OVERWRITE = {} # type: Dict[str, Any]
from homeassistant.helpers.customize import get_overrides
_LOGGER = logging.getLogger(__name__)
@ -57,16 +55,6 @@ def async_generate_entity_id(entity_id_format: str, name: Optional[str],
entity_id_format.format(slugify(name)), current_ids)
def set_customize(customize: Dict[str, Any]) -> None:
"""Overwrite all current customize settings.
Async friendly.
"""
global _OVERWRITE
_OVERWRITE = {key.lower(): val for key, val in customize.items()}
class Entity(object):
"""An abstract class for Home Assistant entities."""
@ -254,7 +242,7 @@ class Entity(object):
end - start)
# Overwrite properties that have been set in the config file.
attr.update(_OVERWRITE.get(self.entity_id, {}))
attr.update(get_overrides(self.hass, self.entity_id))
# Remove hidden property if false so it won't show up.
if not attr.get(ATTR_HIDDEN, True):

View File

@ -165,6 +165,25 @@ def test_entity_ids():
]
def test_ensure_list_csv():
"""Test ensure_list_csv."""
schema = vol.Schema(cv.ensure_list_csv)
options = (
None,
12,
[],
['string'],
'string1,string2'
)
for value in options:
schema(value)
assert schema('string1, string2 ') == [
'string1', 'string2'
]
def test_event_schema():
"""Test event_schema validation."""
options = (
@ -429,6 +448,15 @@ def test_has_at_least_one_key():
schema(value)
def test_ordered_dict_only_dict():
"""Test ordered_dict validator."""
schema = vol.Schema(cv.ordered_dict(cv.match_all, cv.match_all))
for value in (None, [], 100, 'hello'):
with pytest.raises(vol.MultipleInvalid):
schema(value)
def test_ordered_dict_order():
"""Test ordered_dict validator."""
schema = vol.Schema(cv.ordered_dict(int, cv.string))

View File

@ -0,0 +1,87 @@
"""Test the customize helper."""
import homeassistant.helpers.customize as customize
class MockHass(object):
"""Mock object for HassAssistant."""
data = {}
class TestHelpersCustomize(object):
"""Test homeassistant.helpers.customize module."""
def setup_method(self, method):
"""Setup things to be run when tests are started."""
self.entity_id = 'test.test'
self.hass = MockHass()
def _get_overrides(self, overrides):
customize.set_customize(self.hass, overrides)
return customize.get_overrides(self.hass, self.entity_id)
def test_override_single_value(self):
"""Test entity customization through configuration."""
result = self._get_overrides([
{'entity_id': [self.entity_id], 'key': 'value'}])
assert result == {'key': 'value'}
def test_override_multiple_values(self):
"""Test entity customization through configuration."""
result = self._get_overrides([
{'entity_id': [self.entity_id], 'key1': 'value1'},
{'entity_id': [self.entity_id], 'key2': 'value2'}])
assert result == {'key1': 'value1', 'key2': 'value2'}
def test_override_same_value(self):
"""Test entity customization through configuration."""
result = self._get_overrides([
{'entity_id': [self.entity_id], 'key': 'value1'},
{'entity_id': [self.entity_id], 'key': 'value2'}])
assert result == {'key': 'value2'}
def test_override_by_domain(self):
"""Test entity customization through configuration."""
result = self._get_overrides([
{'entity_id': ['test'], 'key': 'value'}])
assert result == {'key': 'value'}
def test_override_by_glob(self):
"""Test entity customization through configuration."""
result = self._get_overrides([
{'entity_id': ['test.?e*'], 'key': 'value'}])
assert result == {'key': 'value'}
def test_override_exact_over_glob_over_domain(self):
"""Test entity customization through configuration."""
result = self._get_overrides([
{'entity_id': ['test.test'], 'key1': 'valueExact'},
{'entity_id': ['test.tes?'],
'key1': 'valueGlob',
'key2': 'valueGlob'},
{'entity_id': ['test'],
'key1': 'valueDomain',
'key2': 'valueDomain',
'key3': 'valueDomain'}])
assert result == {
'key1': 'valueExact',
'key2': 'valueGlob',
'key3': 'valueDomain'}
def test_override_deep_dict(self):
"""Test we can overwrite hidden property to True."""
result = self._get_overrides(
[{'entity_id': [self.entity_id],
'test': {'key1': 'value1', 'key2': 'value2'}},
{'entity_id': [self.entity_id],
'test': {'key3': 'value3', 'key2': 'value22'}}])
assert result['test'] == {
'key1': 'value1',
'key2': 'value22',
'key3': 'value3'}

View File

@ -6,6 +6,7 @@ from unittest.mock import MagicMock
import pytest
import homeassistant.helpers.entity as entity
from homeassistant.helpers.customize import set_customize
from homeassistant.const import ATTR_HIDDEN
from tests.common import get_test_home_assistant
@ -78,7 +79,6 @@ class TestHelpersEntity(object):
def teardown_method(self, method):
"""Stop everything that was started."""
entity.set_customize({})
self.hass.stop()
def test_default_hidden_not_in_attributes(self):
@ -88,7 +88,9 @@ class TestHelpersEntity(object):
def test_overwriting_hidden_property_to_true(self):
"""Test we can overwrite hidden property to True."""
entity.set_customize({self.entity.entity_id: {ATTR_HIDDEN: True}})
set_customize(
self.hass,
[{'entity_id': [self.entity.entity_id], ATTR_HIDDEN: True}])
self.entity.update_ha_state()
state = self.hass.states.get(self.entity.entity_id)

View File

@ -170,16 +170,17 @@ class TestConfig(unittest.TestCase):
os.path.join(CONFIG_DIR, 'non_existing_dir/'), False))
self.assertTrue(mock_print.called)
# pylint: disable=no-self-use
def test_core_config_schema(self):
"""Test core config schema."""
for value in (
{CONF_UNIT_SYSTEM: 'K'},
{'time_zone': 'non-exist'},
{'latitude': '91'},
{'longitude': -181},
{'customize': 'bla'},
{'customize': {'invalid_entity_id': {}}},
{'customize': {'light.sensor': 100}},
{CONF_UNIT_SYSTEM: 'K'},
{'time_zone': 'non-exist'},
{'latitude': '91'},
{'longitude': -181},
{'customize': 'bla'},
{'customize': {'light.sensor': 100}},
{'customize': {'entity_id': []}},
):
with pytest.raises(MultipleInvalid):
config_util.CORE_CONFIG_SCHEMA(value)
@ -196,13 +197,7 @@ class TestConfig(unittest.TestCase):
},
})
def test_entity_customization(self):
"""Test entity customization through configuration."""
config = {CONF_LATITUDE: 50,
CONF_LONGITUDE: 50,
CONF_NAME: 'Test',
CONF_CUSTOMIZE: {'test.test': {'hidden': True}}}
def _compute_state(self, config):
run_coroutine_threadsafe(
config_util.async_process_ha_core_config(self.hass, config),
self.hass.loop).result()
@ -214,10 +209,50 @@ class TestConfig(unittest.TestCase):
self.hass.block_till_done()
state = self.hass.states.get('test.test')
return self.hass.states.get('test.test')
def test_entity_customization_false(self):
"""Test entity customization through configuration."""
config = {CONF_LATITUDE: 50,
CONF_LONGITUDE: 50,
CONF_NAME: 'Test',
CONF_CUSTOMIZE: {
'test.test': {'hidden': False}}}
state = self._compute_state(config)
assert 'hidden' not in state.attributes
def test_entity_customization(self):
"""Test entity customization through configuration."""
config = {CONF_LATITUDE: 50,
CONF_LONGITUDE: 50,
CONF_NAME: 'Test',
CONF_CUSTOMIZE: {'test.test': {'hidden': True}}}
state = self._compute_state(config)
assert state.attributes['hidden']
def test_entity_customization_comma_separated(self):
"""Test entity customization through configuration."""
config = {CONF_LATITUDE: 50,
CONF_LONGITUDE: 50,
CONF_NAME: 'Test',
CONF_CUSTOMIZE: [
{'entity_id': 'test.not_test,test,test.not_t*',
'key1': 'value1'},
{'entity_id': 'test.test,not_test,test.not_t*',
'key2': 'value2'},
{'entity_id': 'test.not_test,not_test,test.t*',
'key3': 'value3'}]}
state = self._compute_state(config)
assert state.attributes['key1'] == 'value1'
assert state.attributes['key2'] == 'value2'
assert state.attributes['key3'] == 'value3'
@mock.patch('homeassistant.config.shutil')
@mock.patch('homeassistant.config.os')
def test_remove_lib_on_upgrade(self, mock_os, mock_shutil):
@ -229,6 +264,7 @@ class TestConfig(unittest.TestCase):
mock_open = mock.mock_open()
with mock.patch('homeassistant.config.open', mock_open, create=True):
opened_file = mock_open.return_value
# pylint: disable=no-member
opened_file.readline.return_value = ha_version
self.hass.config.path = mock.Mock()
@ -258,6 +294,7 @@ class TestConfig(unittest.TestCase):
mock_open = mock.mock_open()
with mock.patch('homeassistant.config.open', mock_open, create=True):
opened_file = mock_open.return_value
# pylint: disable=no-member
opened_file.readline.return_value = ha_version
self.hass.config.path = mock.Mock()