Added config validator for future group platforms (#12592)

* Added cv.EntitiesDoamin(domain) validator

* Check if all entities in string or list belong to domain
* Added tests

* Use factory function and entity_ids

* Different error message

* Typo

* Added entity_domain validator for a single entity_id

* Image_processing platform now uses cv.entity_domain for source validation
pull/12687/head
cdce8p 2018-02-26 08:48:21 +01:00 committed by Paulus Schoutsen
parent 7d5c1581f1
commit 6e6ae173fd
3 changed files with 75 additions and 2 deletions

View File

@ -43,7 +43,7 @@ DEFAULT_TIMEOUT = 10
DEFAULT_CONFIDENCE = 80
SOURCE_SCHEMA = vol.Schema({
vol.Required(CONF_ENTITY_ID): cv.entity_id,
vol.Required(CONF_ENTITY_ID): cv.entity_domain('camera'),
vol.Optional(CONF_NAME): cv.string,
})

View File

@ -18,7 +18,7 @@ from homeassistant.const import (
CONF_ALIAS, CONF_ENTITY_ID, CONF_VALUE_TEMPLATE, WEEKDAYS,
CONF_CONDITION, CONF_BELOW, CONF_ABOVE, CONF_TIMEOUT, SUN_EVENT_SUNSET,
SUN_EVENT_SUNRISE, CONF_UNIT_SYSTEM_IMPERIAL, CONF_UNIT_SYSTEM_METRIC)
from homeassistant.core import valid_entity_id
from homeassistant.core import valid_entity_id, split_entity_id
from homeassistant.exceptions import TemplateError
import homeassistant.util.dt as dt_util
from homeassistant.util import slugify as util_slugify
@ -147,6 +147,29 @@ def entity_ids(value: Union[str, Sequence]) -> Sequence[str]:
return [entity_id(ent_id) for ent_id in value]
def entity_domain(domain: str):
"""Validate that entity belong to domain."""
def validate(value: Any) -> str:
"""Test if entity domain is domain."""
ent_domain = entities_domain(domain)
return ent_domain(value)[0]
return validate
def entities_domain(domain: str):
"""Validate that entities belong to domain."""
def validate(values: Union[str, Sequence]) -> Sequence[str]:
"""Test if entity domain is domain."""
values = entity_ids(values)
for ent_id in values:
if split_entity_id(ent_id)[0] != domain:
raise vol.Invalid(
"Entity ID '{}' does not belong to domain '{}'"
.format(ent_id, domain))
return values
return validate
def enum(enumClass):
"""Create validator for specified enum."""
return vol.All(vol.In(enumClass.__members__), enumClass.__getitem__)

View File

@ -164,6 +164,55 @@ def test_entity_ids():
]
def test_entity_domain():
"""Test entity domain validation."""
schema = vol.Schema(cv.entity_domain('sensor'))
options = (
'invalid_entity',
'cover.demo',
)
for value in options:
with pytest.raises(vol.MultipleInvalid):
print(value)
schema(value)
assert schema('sensor.LIGHT') == 'sensor.light'
def test_entities_domain():
"""Test entities domain validation."""
schema = vol.Schema(cv.entities_domain('sensor'))
options = (
None,
'',
'invalid_entity',
['sensor.light', 'cover.demo'],
['sensor.light', 'sensor_invalid'],
)
for value in options:
with pytest.raises(vol.MultipleInvalid):
schema(value)
options = (
'sensor.light',
['SENSOR.light'],
['sensor.light', 'sensor.demo']
)
for value in options:
schema(value)
assert schema('sensor.LIGHT, sensor.demo ') == [
'sensor.light', 'sensor.demo'
]
assert schema(['sensor.light', 'SENSOR.demo']) == [
'sensor.light', 'sensor.demo'
]
def test_ensure_list_csv():
"""Test ensure_list_csv."""
schema = vol.Schema(cv.ensure_list_csv)
@ -453,6 +502,7 @@ def test_deprecated(caplog):
)
deprecated_schema({'venus': True})
# pylint: disable=len-as-condition
assert len(caplog.records) == 0
deprecated_schema({'mars': True})