diff --git a/homeassistant/components/image_processing/__init__.py b/homeassistant/components/image_processing/__init__.py index 2c2b8364823..061fd5d7074 100644 --- a/homeassistant/components/image_processing/__init__.py +++ b/homeassistant/components/image_processing/__init__.py @@ -43,7 +43,7 @@ DEFAULT_TIMEOUT = 10 DEFAULT_CONFIDENCE = 80 SOURCE_SCHEMA = vol.Schema({ - vol.Required(CONF_ENTITY_ID): cv.entity_id, + vol.Required(CONF_ENTITY_ID): cv.entity_domain('camera'), vol.Optional(CONF_NAME): cv.string, }) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index e32b041ffa2..f8f08fd118f 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -18,7 +18,7 @@ from homeassistant.const import ( CONF_ALIAS, CONF_ENTITY_ID, CONF_VALUE_TEMPLATE, WEEKDAYS, CONF_CONDITION, CONF_BELOW, CONF_ABOVE, CONF_TIMEOUT, SUN_EVENT_SUNSET, SUN_EVENT_SUNRISE, CONF_UNIT_SYSTEM_IMPERIAL, CONF_UNIT_SYSTEM_METRIC) -from homeassistant.core import valid_entity_id +from homeassistant.core import valid_entity_id, split_entity_id from homeassistant.exceptions import TemplateError import homeassistant.util.dt as dt_util from homeassistant.util import slugify as util_slugify @@ -147,6 +147,29 @@ def entity_ids(value: Union[str, Sequence]) -> Sequence[str]: return [entity_id(ent_id) for ent_id in value] +def entity_domain(domain: str): + """Validate that entity belong to domain.""" + def validate(value: Any) -> str: + """Test if entity domain is domain.""" + ent_domain = entities_domain(domain) + return ent_domain(value)[0] + return validate + + +def entities_domain(domain: str): + """Validate that entities belong to domain.""" + def validate(values: Union[str, Sequence]) -> Sequence[str]: + """Test if entity domain is domain.""" + values = entity_ids(values) + for ent_id in values: + if split_entity_id(ent_id)[0] != domain: + raise vol.Invalid( + "Entity ID '{}' does not belong to domain '{}'" + .format(ent_id, domain)) + return values + return validate + + def enum(enumClass): """Create validator for specified enum.""" return vol.All(vol.In(enumClass.__members__), enumClass.__getitem__) diff --git a/tests/helpers/test_config_validation.py b/tests/helpers/test_config_validation.py index 66f0597fc93..90be56bbc7c 100644 --- a/tests/helpers/test_config_validation.py +++ b/tests/helpers/test_config_validation.py @@ -164,6 +164,55 @@ def test_entity_ids(): ] +def test_entity_domain(): + """Test entity domain validation.""" + schema = vol.Schema(cv.entity_domain('sensor')) + + options = ( + 'invalid_entity', + 'cover.demo', + ) + + for value in options: + with pytest.raises(vol.MultipleInvalid): + print(value) + schema(value) + + assert schema('sensor.LIGHT') == 'sensor.light' + + +def test_entities_domain(): + """Test entities domain validation.""" + schema = vol.Schema(cv.entities_domain('sensor')) + + options = ( + None, + '', + 'invalid_entity', + ['sensor.light', 'cover.demo'], + ['sensor.light', 'sensor_invalid'], + ) + + for value in options: + with pytest.raises(vol.MultipleInvalid): + schema(value) + + options = ( + 'sensor.light', + ['SENSOR.light'], + ['sensor.light', 'sensor.demo'] + ) + for value in options: + schema(value) + + assert schema('sensor.LIGHT, sensor.demo ') == [ + 'sensor.light', 'sensor.demo' + ] + assert schema(['sensor.light', 'SENSOR.demo']) == [ + 'sensor.light', 'sensor.demo' + ] + + def test_ensure_list_csv(): """Test ensure_list_csv.""" schema = vol.Schema(cv.ensure_list_csv) @@ -453,6 +502,7 @@ def test_deprecated(caplog): ) deprecated_schema({'venus': True}) + # pylint: disable=len-as-condition assert len(caplog.records) == 0 deprecated_schema({'mars': True})