diff --git a/homeassistant/helpers/selector.py b/homeassistant/helpers/selector.py index 3da4ccd1880..764f8ed49af 100644 --- a/homeassistant/helpers/selector.py +++ b/homeassistant/helpers/selector.py @@ -78,13 +78,13 @@ SINGLE_ENTITY_SELECTOR_CONFIG_SCHEMA = vol.Schema( # Integration that provided the entity vol.Optional("integration"): str, # Domain the entity belongs to - vol.Optional("domain"): str, + vol.Optional("domain"): vol.Any(str, [str]), # Device class of the entity vol.Optional("device_class"): str, } ) -DEVICE_SELECTOR_CONFIG_SCHEMA = vol.Schema( +SINGLE_DEVICE_SELECTOR_CONFIG_SCHEMA = vol.Schema( { # Integration linked to it with a config entry vol.Optional("integration"): str, @@ -94,7 +94,6 @@ DEVICE_SELECTOR_CONFIG_SCHEMA = vol.Schema( vol.Optional("model"): str, # Device has to contain entities matching this selector vol.Optional("entity"): SINGLE_ENTITY_SELECTOR_CONFIG_SCHEMA, - vol.Optional("multiple", default=False): cv.boolean, } ) @@ -140,7 +139,7 @@ class AreaSelector(Selector): CONFIG_SCHEMA = vol.Schema( { vol.Optional("entity"): SINGLE_ENTITY_SELECTOR_CONFIG_SCHEMA, - vol.Optional("device"): DEVICE_SELECTOR_CONFIG_SCHEMA, + vol.Optional("device"): SINGLE_DEVICE_SELECTOR_CONFIG_SCHEMA, vol.Optional("multiple", default=False): cv.boolean, } ) @@ -183,13 +182,82 @@ class BooleanSelector(Selector): return value +@SELECTORS.register("color_rgb") +class ColorRGBSelector(Selector): + """Selector of an RGB color value.""" + + selector_type = "color_rgb" + + CONFIG_SCHEMA = vol.Schema({}) + + def __call__(self, data: Any) -> list[int]: + """Validate the passed selection.""" + value: list[int] = vol.All(list, vol.ExactSequence((cv.byte,) * 3))(data) + return value + + +@SELECTORS.register("color_temp") +class ColorTempSelector(Selector): + """Selector of an color temperature.""" + + selector_type = "color_temp" + + CONFIG_SCHEMA = vol.Schema( + { + vol.Optional("max_mireds"): vol.Coerce(int), + vol.Optional("min_mireds"): vol.Coerce(int), + } + ) + + def __call__(self, data: Any) -> int: + """Validate the passed selection.""" + value: int = vol.All( + vol.Coerce(float), + vol.Range( + min=self.config.get("min_mireds"), + max=self.config.get("max_mireds"), + ), + )(data) + return value + + +@SELECTORS.register("date") +class DateSelector(Selector): + """Selector of a date.""" + + selector_type = "date" + + CONFIG_SCHEMA = vol.Schema({}) + + def __call__(self, data: Any) -> Any: + """Validate the passed selection.""" + cv.date(data) + return data + + +@SELECTORS.register("datetime") +class DateTimeSelector(Selector): + """Selector of a datetime.""" + + selector_type = "datetime" + + CONFIG_SCHEMA = vol.Schema({}) + + def __call__(self, data: Any) -> Any: + """Validate the passed selection.""" + cv.datetime(data) + return data + + @SELECTORS.register("device") class DeviceSelector(Selector): """Selector of a single or list of devices.""" selector_type = "device" - CONFIG_SCHEMA = DEVICE_SELECTOR_CONFIG_SCHEMA + CONFIG_SCHEMA = SINGLE_DEVICE_SELECTOR_CONFIG_SCHEMA.extend( + {vol.Optional("multiple", default=False): cv.boolean} + ) def __call__(self, data: Any) -> str | list[str]: """Validate the passed selection.""" @@ -226,23 +294,34 @@ class EntitySelector(Selector): selector_type = "entity" CONFIG_SCHEMA = SINGLE_ENTITY_SELECTOR_CONFIG_SCHEMA.extend( - {vol.Optional("multiple", default=False): cv.boolean} + { + vol.Optional("exclude_entities"): [str], + vol.Optional("include_entities"): [str], + vol.Optional("multiple", default=False): cv.boolean, + } ) def __call__(self, data: Any) -> str | list[str]: """Validate the passed selection.""" + include_entities = self.config.get("include_entities") + exclude_entities = self.config.get("exclude_entities") + def validate(e_or_u: str) -> str: e_or_u = cv.entity_id_or_uuid(e_or_u) if not valid_entity_id(e_or_u): return e_or_u - if allowed_domain := self.config.get("domain"): + if allowed_domains := cv.ensure_list(self.config.get("domain")): domain = split_entity_id(e_or_u)[0] - if domain != allowed_domain: + if domain not in allowed_domains: raise vol.Invalid( f"Entity {e_or_u} belongs to domain {domain}, " - f"expected {allowed_domain}" + f"expected {allowed_domains}" ) + if include_entities: + vol.In(include_entities)(e_or_u) + if exclude_entities: + vol.NotIn(exclude_entities)(e_or_u) return e_or_u if not self.config["multiple"]: @@ -460,7 +539,7 @@ class TargetSelector(Selector): CONFIG_SCHEMA = vol.Schema( { vol.Optional("entity"): SINGLE_ENTITY_SELECTOR_CONFIG_SCHEMA, - vol.Optional("device"): DEVICE_SELECTOR_CONFIG_SCHEMA, + vol.Optional("device"): SINGLE_DEVICE_SELECTOR_CONFIG_SCHEMA, } ) diff --git a/tests/helpers/test_selector.py b/tests/helpers/test_selector.py index 5a2bc0c0baa..f1d7c83f211 100644 --- a/tests/helpers/test_selector.py +++ b/tests/helpers/test_selector.py @@ -110,6 +110,11 @@ def test_device_selector_schema(schema, valid_selections, invalid_selections): ({}, ("sensor.abc123", FAKE_UUID), (None, "abc123")), ({"integration": "zha"}, ("sensor.abc123", FAKE_UUID), (None, "abc123")), ({"domain": "light"}, ("light.abc123", FAKE_UUID), (None, "sensor.abc123")), + ( + {"domain": ["light", "sensor"]}, + ("light.abc123", "sensor.abc123", FAKE_UUID), + (None, "dog.abc123"), + ), ({"device_class": "motion"}, ("sensor.abc123", FAKE_UUID), (None, "abc123")), ( {"integration": "zha", "domain": "light"}, @@ -132,6 +137,26 @@ def test_device_selector_schema(schema, valid_selections, invalid_selections): ["sensor.abc123", "light.def456"], ), ), + ( + { + "include_entities": ["sensor.abc123", "sensor.def456", "sensor.ghi789"], + "exclude_entities": ["sensor.ghi789", "sensor.jkl123"], + }, + ("sensor.abc123", FAKE_UUID), + ("sensor.ghi789", "sensor.jkl123"), + ), + ( + { + "multiple": True, + "include_entities": ["sensor.abc123", "sensor.def456", "sensor.ghi789"], + "exclude_entities": ["sensor.ghi789", "sensor.jkl123"], + }, + (["sensor.abc123", "sensor.def456"], ["sensor.abc123", FAKE_UUID]), + ( + ["sensor.abc123", "sensor.jkl123"], + ["sensor.abc123", "sensor.ghi789"], + ), + ), ), ) def test_entity_selector_schema(schema, valid_selections, invalid_selections): @@ -490,3 +515,72 @@ def test_location_selector_schema(schema, valid_selections, invalid_selections): """Test location selector.""" _test_selector("location", schema, valid_selections, invalid_selections) + + +@pytest.mark.parametrize( + "schema,valid_selections,invalid_selections", + ( + ( + {}, + ([0, 0, 0], [255, 255, 255], [0.0, 0.0, 0.0], [255.0, 255.0, 255.0]), + (None, "abc", [0, 0, "nil"], (255, 255, 255)), + ), + ), +) +def test_rgb_color_selector_schema(schema, valid_selections, invalid_selections): + """Test color_rgb selector.""" + + _test_selector("color_rgb", schema, valid_selections, invalid_selections) + + +@pytest.mark.parametrize( + "schema,valid_selections,invalid_selections", + ( + ( + {}, + (100, 100.0), + (None, "abc", [100]), + ), + ( + {"min_mireds": 100, "max_mireds": 200}, + (100, 200), + (99, 201), + ), + ), +) +def test_color_tempselector_schema(schema, valid_selections, invalid_selections): + """Test color_temp selector.""" + + _test_selector("color_temp", schema, valid_selections, invalid_selections) + + +@pytest.mark.parametrize( + "schema,valid_selections,invalid_selections", + ( + ( + {}, + ("2022-03-24",), + (None, "abc", "00:00", "2022-03-24 00:00", "2022-03-32"), + ), + ), +) +def test_date_selector_schema(schema, valid_selections, invalid_selections): + """Test date selector.""" + + _test_selector("date", schema, valid_selections, invalid_selections) + + +@pytest.mark.parametrize( + "schema,valid_selections,invalid_selections", + ( + ( + {}, + ("2022-03-24 00:00", "2022-03-24"), + (None, "abc", "00:00", "2022-03-24 24:01"), + ), + ), +) +def test_datetime_selector_schema(schema, valid_selections, invalid_selections): + """Test datetime selector.""" + + _test_selector("datetime", schema, valid_selections, invalid_selections)