Add foundation for integration services (#30813)
* Add foundation for integration services * Fix tests * Remove async_get_platform * Migrate Sonos partially to EntityPlatform.async_register_entity_service * Tweaks * Move other Sonos services to media player domain * Move other Sonos services to media player domain * Address comments * Remove lock * Fix typos * Use make_entity_service_schema * Add area extraction to async_extract_entities Co-authored-by: Anders Melchiorsen <amelchio@nogoto.net>pull/30993/head
parent
f20b3515f2
commit
0c3ffbe282
|
@ -1,13 +1,10 @@
|
|||
"""Support to embed Sonos."""
|
||||
import asyncio
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.media_player import DOMAIN as MP_DOMAIN
|
||||
from homeassistant.const import ATTR_ENTITY_ID, ATTR_TIME, CONF_HOSTS
|
||||
from homeassistant.const import CONF_HOSTS
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
|
@ -33,91 +30,12 @@ CONFIG_SCHEMA = vol.Schema(
|
|||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
SERVICE_JOIN = "join"
|
||||
SERVICE_UNJOIN = "unjoin"
|
||||
SERVICE_SNAPSHOT = "snapshot"
|
||||
SERVICE_RESTORE = "restore"
|
||||
SERVICE_SET_TIMER = "set_sleep_timer"
|
||||
SERVICE_CLEAR_TIMER = "clear_sleep_timer"
|
||||
SERVICE_UPDATE_ALARM = "update_alarm"
|
||||
SERVICE_SET_OPTION = "set_option"
|
||||
SERVICE_PLAY_QUEUE = "play_queue"
|
||||
|
||||
ATTR_SLEEP_TIME = "sleep_time"
|
||||
ATTR_ALARM_ID = "alarm_id"
|
||||
ATTR_VOLUME = "volume"
|
||||
ATTR_ENABLED = "enabled"
|
||||
ATTR_INCLUDE_LINKED_ZONES = "include_linked_zones"
|
||||
ATTR_MASTER = "master"
|
||||
ATTR_WITH_GROUP = "with_group"
|
||||
ATTR_NIGHT_SOUND = "night_sound"
|
||||
ATTR_SPEECH_ENHANCE = "speech_enhance"
|
||||
ATTR_QUEUE_POSITION = "queue_position"
|
||||
|
||||
SONOS_JOIN_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_MASTER): cv.entity_id,
|
||||
vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids,
|
||||
}
|
||||
)
|
||||
|
||||
SONOS_UNJOIN_SCHEMA = vol.Schema({vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids})
|
||||
|
||||
SONOS_STATES_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional(ATTR_ENTITY_ID): cv.comp_entity_ids,
|
||||
vol.Optional(ATTR_WITH_GROUP, default=True): cv.boolean,
|
||||
}
|
||||
)
|
||||
|
||||
SONOS_SET_TIMER_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
|
||||
vol.Required(ATTR_SLEEP_TIME): vol.All(
|
||||
vol.Coerce(int), vol.Range(min=0, max=86399)
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
SONOS_CLEAR_TIMER_SCHEMA = vol.Schema(
|
||||
{vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids}
|
||||
)
|
||||
|
||||
SONOS_UPDATE_ALARM_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
|
||||
vol.Required(ATTR_ALARM_ID): cv.positive_int,
|
||||
vol.Optional(ATTR_TIME): cv.time,
|
||||
vol.Optional(ATTR_VOLUME): cv.small_float,
|
||||
vol.Optional(ATTR_ENABLED): cv.boolean,
|
||||
vol.Optional(ATTR_INCLUDE_LINKED_ZONES): cv.boolean,
|
||||
}
|
||||
)
|
||||
|
||||
SONOS_SET_OPTION_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
|
||||
vol.Optional(ATTR_NIGHT_SOUND): cv.boolean,
|
||||
vol.Optional(ATTR_SPEECH_ENHANCE): cv.boolean,
|
||||
}
|
||||
)
|
||||
|
||||
SONOS_PLAY_QUEUE_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_ENTITY_ID): cv.comp_entity_ids,
|
||||
vol.Optional(ATTR_QUEUE_POSITION, default=0): cv.positive_int,
|
||||
}
|
||||
)
|
||||
|
||||
DATA_SERVICE_EVENT = "sonos_service_idle"
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Set up the Sonos component."""
|
||||
conf = config.get(DOMAIN)
|
||||
|
||||
hass.data[DOMAIN] = conf or {}
|
||||
hass.data[DATA_SERVICE_EVENT] = asyncio.Event()
|
||||
|
||||
if conf is not None:
|
||||
hass.async_create_task(
|
||||
|
@ -126,48 +44,6 @@ async def async_setup(hass, config):
|
|||
)
|
||||
)
|
||||
|
||||
async def service_handle(service):
|
||||
"""Dispatch a service call."""
|
||||
hass.data[DATA_SERVICE_EVENT].clear()
|
||||
async_dispatcher_send(hass, DOMAIN, service.service, service.data)
|
||||
await hass.data[DATA_SERVICE_EVENT].wait()
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_JOIN, service_handle, schema=SONOS_JOIN_SCHEMA
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_UNJOIN, service_handle, schema=SONOS_UNJOIN_SCHEMA
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_SNAPSHOT, service_handle, schema=SONOS_STATES_SCHEMA
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_RESTORE, service_handle, schema=SONOS_STATES_SCHEMA
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_SET_TIMER, service_handle, schema=SONOS_SET_TIMER_SCHEMA
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_CLEAR_TIMER, service_handle, schema=SONOS_CLEAR_TIMER_SCHEMA
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_UPDATE_ALARM, service_handle, schema=SONOS_UPDATE_ALARM_SCHEMA
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_SET_OPTION, service_handle, schema=SONOS_SET_OPTION_SCHEMA
|
||||
)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_PLAY_QUEUE, service_handle, schema=SONOS_PLAY_QUEUE_SCHEMA
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import pysonos
|
|||
from pysonos import alarms
|
||||
from pysonos.exceptions import SoCoException, SoCoUPnPException
|
||||
import pysonos.snapshot
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.media_player import MediaPlayerDevice
|
||||
from homeassistant.components.media_player.const import (
|
||||
|
@ -30,42 +31,16 @@ from homeassistant.components.media_player.const import (
|
|||
SUPPORT_VOLUME_MUTE,
|
||||
SUPPORT_VOLUME_SET,
|
||||
)
|
||||
from homeassistant.const import (
|
||||
ENTITY_MATCH_ALL,
|
||||
STATE_IDLE,
|
||||
STATE_PAUSED,
|
||||
STATE_PLAYING,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.const import ATTR_TIME, STATE_IDLE, STATE_PAUSED, STATE_PLAYING
|
||||
from homeassistant.core import ServiceCall, callback
|
||||
from homeassistant.helpers import config_validation as cv, entity_platform, service
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from . import (
|
||||
ATTR_ALARM_ID,
|
||||
ATTR_ENABLED,
|
||||
ATTR_INCLUDE_LINKED_ZONES,
|
||||
ATTR_MASTER,
|
||||
ATTR_NIGHT_SOUND,
|
||||
ATTR_QUEUE_POSITION,
|
||||
ATTR_SLEEP_TIME,
|
||||
ATTR_SPEECH_ENHANCE,
|
||||
ATTR_TIME,
|
||||
ATTR_VOLUME,
|
||||
ATTR_WITH_GROUP,
|
||||
CONF_ADVERTISE_ADDR,
|
||||
CONF_HOSTS,
|
||||
CONF_INTERFACE_ADDR,
|
||||
DATA_SERVICE_EVENT,
|
||||
DOMAIN as SONOS_DOMAIN,
|
||||
SERVICE_CLEAR_TIMER,
|
||||
SERVICE_JOIN,
|
||||
SERVICE_PLAY_QUEUE,
|
||||
SERVICE_RESTORE,
|
||||
SERVICE_SET_OPTION,
|
||||
SERVICE_SET_TIMER,
|
||||
SERVICE_SNAPSHOT,
|
||||
SERVICE_UNJOIN,
|
||||
SERVICE_UPDATE_ALARM,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
@ -97,6 +72,27 @@ ATTR_SONOS_GROUP = "sonos_group"
|
|||
|
||||
UPNP_ERRORS_TO_IGNORE = ["701", "711", "712"]
|
||||
|
||||
SERVICE_JOIN = "join"
|
||||
SERVICE_UNJOIN = "unjoin"
|
||||
SERVICE_SNAPSHOT = "snapshot"
|
||||
SERVICE_RESTORE = "restore"
|
||||
SERVICE_SET_TIMER = "set_sleep_timer"
|
||||
SERVICE_CLEAR_TIMER = "clear_sleep_timer"
|
||||
SERVICE_UPDATE_ALARM = "update_alarm"
|
||||
SERVICE_SET_OPTION = "set_option"
|
||||
SERVICE_PLAY_QUEUE = "play_queue"
|
||||
|
||||
ATTR_SLEEP_TIME = "sleep_time"
|
||||
ATTR_ALARM_ID = "alarm_id"
|
||||
ATTR_VOLUME = "volume"
|
||||
ATTR_ENABLED = "enabled"
|
||||
ATTR_INCLUDE_LINKED_ZONES = "include_linked_zones"
|
||||
ATTR_MASTER = "master"
|
||||
ATTR_WITH_GROUP = "with_group"
|
||||
ATTR_NIGHT_SOUND = "night_sound"
|
||||
ATTR_SPEECH_ENHANCE = "speech_enhance"
|
||||
ATTR_QUEUE_POSITION = "queue_position"
|
||||
|
||||
|
||||
class SonosData:
|
||||
"""Storage class for platform global data."""
|
||||
|
@ -176,46 +172,101 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
|
|||
_LOGGER.debug("Adding discovery job")
|
||||
hass.async_add_executor_job(_discovery)
|
||||
|
||||
async def async_service_handle(service, data):
|
||||
platform = entity_platform.current_platform.get()
|
||||
|
||||
async def async_service_handle(service_call: ServiceCall):
|
||||
"""Handle dispatched services."""
|
||||
entity_ids = data.get("entity_id")
|
||||
entities = hass.data[DATA_SONOS].entities
|
||||
if entity_ids and entity_ids != ENTITY_MATCH_ALL:
|
||||
entities = [e for e in entities if e.entity_id in entity_ids]
|
||||
entities = await platform.async_extract_from_service(service_call)
|
||||
|
||||
if service == SERVICE_JOIN:
|
||||
master = [
|
||||
e
|
||||
for e in hass.data[DATA_SONOS].entities
|
||||
if e.entity_id == data[ATTR_MASTER]
|
||||
]
|
||||
if not entities:
|
||||
return
|
||||
|
||||
if service_call.service == SERVICE_JOIN:
|
||||
master = platform.entities.get(service_call.data[ATTR_MASTER])
|
||||
if master:
|
||||
await SonosEntity.join_multi(hass, master[0], entities)
|
||||
elif service == SERVICE_UNJOIN:
|
||||
await SonosEntity.join_multi(hass, master, entities)
|
||||
else:
|
||||
_LOGGER.error(
|
||||
"Invalid master specified for join service: %s",
|
||||
service_call.data[ATTR_MASTER],
|
||||
)
|
||||
elif service_call.service == SERVICE_UNJOIN:
|
||||
await SonosEntity.unjoin_multi(hass, entities)
|
||||
elif service == SERVICE_SNAPSHOT:
|
||||
await SonosEntity.snapshot_multi(hass, entities, data[ATTR_WITH_GROUP])
|
||||
elif service == SERVICE_RESTORE:
|
||||
await SonosEntity.restore_multi(hass, entities, data[ATTR_WITH_GROUP])
|
||||
else:
|
||||
for entity in entities:
|
||||
if service == SERVICE_SET_TIMER:
|
||||
call = entity.set_sleep_timer
|
||||
elif service == SERVICE_CLEAR_TIMER:
|
||||
call = entity.clear_sleep_timer
|
||||
elif service == SERVICE_UPDATE_ALARM:
|
||||
call = entity.set_alarm
|
||||
elif service == SERVICE_SET_OPTION:
|
||||
call = entity.set_option
|
||||
elif service == SERVICE_PLAY_QUEUE:
|
||||
call = entity.play_queue
|
||||
elif service_call.service == SERVICE_SNAPSHOT:
|
||||
await SonosEntity.snapshot_multi(
|
||||
hass, entities, service_call.data[ATTR_WITH_GROUP]
|
||||
)
|
||||
elif service_call.service == SERVICE_RESTORE:
|
||||
await SonosEntity.restore_multi(
|
||||
hass, entities, service_call.data[ATTR_WITH_GROUP]
|
||||
)
|
||||
|
||||
hass.async_add_executor_job(call, data)
|
||||
service.async_register_admin_service(
|
||||
hass,
|
||||
SONOS_DOMAIN,
|
||||
SERVICE_JOIN,
|
||||
async_service_handle,
|
||||
cv.make_entity_service_schema({vol.Required(ATTR_MASTER): cv.entity_id}),
|
||||
)
|
||||
|
||||
# We are ready for the next service call
|
||||
hass.data[DATA_SERVICE_EVENT].set()
|
||||
service.async_register_admin_service(
|
||||
hass,
|
||||
SONOS_DOMAIN,
|
||||
SERVICE_UNJOIN,
|
||||
async_service_handle,
|
||||
cv.make_entity_service_schema({}),
|
||||
)
|
||||
|
||||
async_dispatcher_connect(hass, SONOS_DOMAIN, async_service_handle)
|
||||
join_unjoin_schema = cv.make_entity_service_schema(
|
||||
{vol.Optional(ATTR_WITH_GROUP, default=True): cv.boolean}
|
||||
)
|
||||
|
||||
service.async_register_admin_service(
|
||||
hass, SONOS_DOMAIN, SERVICE_SNAPSHOT, async_service_handle, join_unjoin_schema
|
||||
)
|
||||
|
||||
service.async_register_admin_service(
|
||||
hass, SONOS_DOMAIN, SERVICE_RESTORE, async_service_handle, join_unjoin_schema
|
||||
)
|
||||
|
||||
platform.async_register_entity_service(
|
||||
SERVICE_SET_TIMER,
|
||||
{
|
||||
vol.Required(ATTR_SLEEP_TIME): vol.All(
|
||||
vol.Coerce(int), vol.Range(min=0, max=86399)
|
||||
)
|
||||
},
|
||||
"set_sleep_timer",
|
||||
)
|
||||
|
||||
platform.async_register_entity_service(SERVICE_CLEAR_TIMER, {}, "clear_sleep_timer")
|
||||
|
||||
platform.async_register_entity_service(
|
||||
SERVICE_UPDATE_ALARM,
|
||||
{
|
||||
vol.Required(ATTR_ALARM_ID): cv.positive_int,
|
||||
vol.Optional(ATTR_TIME): cv.time,
|
||||
vol.Optional(ATTR_VOLUME): cv.small_float,
|
||||
vol.Optional(ATTR_ENABLED): cv.boolean,
|
||||
vol.Optional(ATTR_INCLUDE_LINKED_ZONES): cv.boolean,
|
||||
},
|
||||
"set_alarm",
|
||||
)
|
||||
|
||||
platform.async_register_entity_service(
|
||||
SERVICE_SET_OPTION,
|
||||
{
|
||||
vol.Optional(ATTR_NIGHT_SOUND): cv.boolean,
|
||||
vol.Optional(ATTR_SPEECH_ENHANCE): cv.boolean,
|
||||
},
|
||||
"set_option",
|
||||
)
|
||||
|
||||
platform.async_register_entity_service(
|
||||
SERVICE_PLAY_QUEUE,
|
||||
{vol.Optional(ATTR_QUEUE_POSITION): cv.positive_int},
|
||||
"play_queue",
|
||||
)
|
||||
|
||||
|
||||
class _ProcessSonosEventQueue:
|
||||
|
@ -480,10 +531,10 @@ class SonosEntity(MediaPlayerDevice):
|
|||
|
||||
player = self.soco
|
||||
|
||||
def subscribe(service, action):
|
||||
def subscribe(sonos_service, action):
|
||||
"""Add a subscription to a pysonos service."""
|
||||
queue = _ProcessSonosEventQueue(action)
|
||||
sub = service.subscribe(auto_renew=True, event_queue=queue)
|
||||
sub = sonos_service.subscribe(auto_renew=True, event_queue=queue)
|
||||
self._subscriptions.append(sub)
|
||||
|
||||
subscribe(player.avTransport, self.update_media)
|
||||
|
@ -1147,52 +1198,53 @@ class SonosEntity(MediaPlayerDevice):
|
|||
|
||||
@soco_error()
|
||||
@soco_coordinator
|
||||
def set_sleep_timer(self, data):
|
||||
def set_sleep_timer(self, sleep_time):
|
||||
"""Set the timer on the player."""
|
||||
self.soco.set_sleep_timer(data[ATTR_SLEEP_TIME])
|
||||
self.soco.set_sleep_timer(sleep_time)
|
||||
|
||||
@soco_error()
|
||||
@soco_coordinator
|
||||
def clear_sleep_timer(self, data):
|
||||
def clear_sleep_timer(self):
|
||||
"""Clear the timer on the player."""
|
||||
self.soco.set_sleep_timer(None)
|
||||
|
||||
@soco_error()
|
||||
@soco_coordinator
|
||||
def set_alarm(self, data):
|
||||
def set_alarm(
|
||||
self, alarm_id, time=None, volume=None, enabled=None, include_linked_zones=None
|
||||
):
|
||||
"""Set the alarm clock on the player."""
|
||||
|
||||
alarm = None
|
||||
for one_alarm in alarms.get_alarms(self.soco):
|
||||
# pylint: disable=protected-access
|
||||
if one_alarm._alarm_id == str(data[ATTR_ALARM_ID]):
|
||||
if one_alarm._alarm_id == str(alarm_id):
|
||||
alarm = one_alarm
|
||||
if alarm is None:
|
||||
_LOGGER.warning("did not find alarm with id %s", data[ATTR_ALARM_ID])
|
||||
_LOGGER.warning("did not find alarm with id %s", alarm_id)
|
||||
return
|
||||
if ATTR_TIME in data:
|
||||
alarm.start_time = data[ATTR_TIME]
|
||||
if ATTR_VOLUME in data:
|
||||
alarm.volume = int(data[ATTR_VOLUME] * 100)
|
||||
if ATTR_ENABLED in data:
|
||||
alarm.enabled = data[ATTR_ENABLED]
|
||||
if ATTR_INCLUDE_LINKED_ZONES in data:
|
||||
alarm.include_linked_zones = data[ATTR_INCLUDE_LINKED_ZONES]
|
||||
if time is not None:
|
||||
alarm.start_time = time
|
||||
if volume is not None:
|
||||
alarm.volume = int(volume * 100)
|
||||
if enabled is not None:
|
||||
alarm.enabled = enabled
|
||||
if include_linked_zones is not None:
|
||||
alarm.include_linked_zones = include_linked_zones
|
||||
alarm.save()
|
||||
|
||||
@soco_error()
|
||||
def set_option(self, data):
|
||||
def set_option(self, night_sound=None, speech_enhance=None):
|
||||
"""Modify playback options."""
|
||||
if ATTR_NIGHT_SOUND in data and self._night_sound is not None:
|
||||
self.soco.night_mode = data[ATTR_NIGHT_SOUND]
|
||||
if night_sound is not None and self._night_sound is not None:
|
||||
self.soco.night_mode = night_sound
|
||||
|
||||
if ATTR_SPEECH_ENHANCE in data and self._speech_enhance is not None:
|
||||
self.soco.dialog_mode = data[ATTR_SPEECH_ENHANCE]
|
||||
if speech_enhance is not None and self._speech_enhance is not None:
|
||||
self.soco.dialog_mode = speech_enhance
|
||||
|
||||
@soco_error()
|
||||
def play_queue(self, data):
|
||||
def play_queue(self, queue_position=0):
|
||||
"""Start playing the queue."""
|
||||
self.soco.play_from_queue(data[ATTR_QUEUE_POSITION])
|
||||
self.soco.play_from_queue(queue_position)
|
||||
|
||||
@property
|
||||
def device_state_attributes(self):
|
||||
|
|
|
@ -6,17 +6,15 @@ import logging
|
|||
|
||||
from homeassistant import config as conf_util
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
CONF_ENTITY_NAMESPACE,
|
||||
CONF_SCAN_INTERVAL,
|
||||
ENTITY_MATCH_ALL,
|
||||
)
|
||||
from homeassistant.const import CONF_ENTITY_NAMESPACE, CONF_SCAN_INTERVAL
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_per_platform, discovery
|
||||
from homeassistant.helpers.config_validation import make_entity_service_schema
|
||||
from homeassistant.helpers.service import async_extract_entity_ids
|
||||
from homeassistant.helpers import (
|
||||
config_per_platform,
|
||||
config_validation as cv,
|
||||
discovery,
|
||||
service,
|
||||
)
|
||||
from homeassistant.loader import async_get_integration, bind_hass
|
||||
from homeassistant.setup import async_prepare_setup_platform
|
||||
|
||||
|
@ -166,39 +164,27 @@ class EntityComponent:
|
|||
await platform.async_reset()
|
||||
return True
|
||||
|
||||
async def async_extract_from_service(self, service, expand_group=True):
|
||||
async def async_extract_from_service(self, service_call, expand_group=True):
|
||||
"""Extract all known and available entities from a service call.
|
||||
|
||||
Will return an empty list if entities specified but unknown.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
data_ent_id = service.data.get(ATTR_ENTITY_ID)
|
||||
|
||||
if data_ent_id is None:
|
||||
return []
|
||||
|
||||
if data_ent_id == ENTITY_MATCH_ALL:
|
||||
return [entity for entity in self.entities if entity.available]
|
||||
|
||||
entity_ids = await async_extract_entity_ids(self.hass, service, expand_group)
|
||||
return [
|
||||
entity
|
||||
for entity in self.entities
|
||||
if entity.available and entity.entity_id in entity_ids
|
||||
]
|
||||
return await service.async_extract_entities(
|
||||
self.hass, self.entities, service_call, expand_group
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_register_entity_service(self, name, schema, func, required_features=None):
|
||||
"""Register an entity service."""
|
||||
if isinstance(schema, dict):
|
||||
schema = make_entity_service_schema(schema)
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
|
||||
async def handle_service(call):
|
||||
"""Handle the service."""
|
||||
service_name = f"{self.domain}.{name}"
|
||||
await self.hass.helpers.service.entity_service_call(
|
||||
self._platforms.values(), func, call, service_name, required_features
|
||||
self._platforms.values(), func, call, required_features
|
||||
)
|
||||
|
||||
self.hass.services.async_register(self.domain, name, handle_service, schema)
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Optional
|
|||
from homeassistant.const import DEVICE_DEFAULT_NAME
|
||||
from homeassistant.core import callback, split_entity_id, valid_entity_id
|
||||
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
|
||||
from homeassistant.helpers import config_validation as cv, service
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
|
||||
from .entity_registry import DISABLED_INTEGRATION
|
||||
|
@ -194,7 +195,11 @@ class EntityPlatform:
|
|||
)
|
||||
return False
|
||||
except Exception: # pylint: disable=broad-except
|
||||
logger.exception("Error while setting up platform %s", self.platform_name)
|
||||
logger.exception(
|
||||
"Error while setting up %s platform for %s",
|
||||
self.platform_name,
|
||||
self.domain,
|
||||
)
|
||||
return False
|
||||
finally:
|
||||
warn_task.cancel()
|
||||
|
@ -449,6 +454,33 @@ class EntityPlatform:
|
|||
self._async_unsub_polling()
|
||||
self._async_unsub_polling = None
|
||||
|
||||
async def async_extract_from_service(self, service_call, expand_group=True):
|
||||
"""Extract all known and available entities from a service call.
|
||||
|
||||
Will return an empty list if entities specified but unknown.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
return await service.async_extract_entities(
|
||||
self.hass, self.entities.values(), service_call, expand_group
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_register_entity_service(self, name, schema, func, required_features=None):
|
||||
"""Register an entity service."""
|
||||
if isinstance(schema, dict):
|
||||
schema = cv.make_entity_service_schema(schema)
|
||||
|
||||
async def handle_service(call):
|
||||
"""Handle the service."""
|
||||
await service.entity_service_call(
|
||||
self.hass, [self], func, call, required_features
|
||||
)
|
||||
|
||||
self.hass.services.async_register(
|
||||
self.platform_name, name, handle_service, schema
|
||||
)
|
||||
|
||||
async def _update_entity_states(self, now: datetime) -> None:
|
||||
"""Update the states of all the polling entities.
|
||||
|
||||
|
|
|
@ -108,13 +108,31 @@ def extract_entity_ids(hass, service_call, expand_group=True):
|
|||
).result()
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_extract_entities(hass, entities, service_call, expand_group=True):
|
||||
"""Extract a list of entity objects from a service call.
|
||||
|
||||
Will convert group entity ids to the entity ids it represents.
|
||||
"""
|
||||
data_ent_id = service_call.data.get(ATTR_ENTITY_ID)
|
||||
|
||||
if data_ent_id == ENTITY_MATCH_ALL:
|
||||
return [entity for entity in entities if entity.available]
|
||||
|
||||
entity_ids = await async_extract_entity_ids(hass, service_call, expand_group)
|
||||
|
||||
return [
|
||||
entity
|
||||
for entity in entities
|
||||
if entity.available and entity.entity_id in entity_ids
|
||||
]
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_extract_entity_ids(hass, service_call, expand_group=True):
|
||||
"""Extract a list of entity ids from a service call.
|
||||
|
||||
Will convert group entity ids to the entity ids it represents.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
entity_ids = service_call.data.get(ATTR_ENTITY_ID)
|
||||
area_ids = service_call.data.get(ATTR_AREA_ID)
|
||||
|
@ -244,9 +262,7 @@ def async_set_service_schema(hass, domain, service, schema):
|
|||
|
||||
|
||||
@bind_hass
|
||||
async def entity_service_call(
|
||||
hass, platforms, func, call, service_name="", required_features=None
|
||||
):
|
||||
async def entity_service_call(hass, platforms, func, call, required_features=None):
|
||||
"""Handle an entity service call.
|
||||
|
||||
Calls all platforms simultaneously.
|
||||
|
|
|
@ -23,6 +23,7 @@ import homeassistant.helpers.config_validation as cv
|
|||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
MockEntity,
|
||||
get_test_home_assistant,
|
||||
mock_coro,
|
||||
mock_device_registry,
|
||||
|
@ -64,6 +65,54 @@ def mock_entities():
|
|||
return entities
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def area_mock(hass):
|
||||
"""Mock including area info."""
|
||||
hass.states.async_set("light.Bowl", STATE_ON)
|
||||
hass.states.async_set("light.Ceiling", STATE_OFF)
|
||||
hass.states.async_set("light.Kitchen", STATE_OFF)
|
||||
|
||||
device_in_area = dev_reg.DeviceEntry(area_id="test-area")
|
||||
device_no_area = dev_reg.DeviceEntry()
|
||||
device_diff_area = dev_reg.DeviceEntry(area_id="diff-area")
|
||||
|
||||
mock_device_registry(
|
||||
hass,
|
||||
{
|
||||
device_in_area.id: device_in_area,
|
||||
device_no_area.id: device_no_area,
|
||||
device_diff_area.id: device_diff_area,
|
||||
},
|
||||
)
|
||||
|
||||
entity_in_area = ent_reg.RegistryEntry(
|
||||
entity_id="light.in_area",
|
||||
unique_id="in-area-id",
|
||||
platform="test",
|
||||
device_id=device_in_area.id,
|
||||
)
|
||||
entity_no_area = ent_reg.RegistryEntry(
|
||||
entity_id="light.no_area",
|
||||
unique_id="no-area-id",
|
||||
platform="test",
|
||||
device_id=device_no_area.id,
|
||||
)
|
||||
entity_diff_area = ent_reg.RegistryEntry(
|
||||
entity_id="light.diff_area",
|
||||
unique_id="diff-area-id",
|
||||
platform="test",
|
||||
device_id=device_diff_area.id,
|
||||
)
|
||||
mock_registry(
|
||||
hass,
|
||||
{
|
||||
entity_in_area.entity_id: entity_in_area,
|
||||
entity_no_area.entity_id: entity_no_area,
|
||||
entity_diff_area.entity_id: entity_diff_area,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TestServiceHelpers(unittest.TestCase):
|
||||
"""Test the Home Assistant service helpers."""
|
||||
|
||||
|
@ -204,52 +253,8 @@ async def test_extract_entity_ids(hass):
|
|||
)
|
||||
|
||||
|
||||
async def test_extract_entity_ids_from_area(hass):
|
||||
async def test_extract_entity_ids_from_area(hass, area_mock):
|
||||
"""Test extract_entity_ids method with areas."""
|
||||
hass.states.async_set("light.Bowl", STATE_ON)
|
||||
hass.states.async_set("light.Ceiling", STATE_OFF)
|
||||
hass.states.async_set("light.Kitchen", STATE_OFF)
|
||||
|
||||
device_in_area = dev_reg.DeviceEntry(area_id="test-area")
|
||||
device_no_area = dev_reg.DeviceEntry()
|
||||
device_diff_area = dev_reg.DeviceEntry(area_id="diff-area")
|
||||
|
||||
mock_device_registry(
|
||||
hass,
|
||||
{
|
||||
device_in_area.id: device_in_area,
|
||||
device_no_area.id: device_no_area,
|
||||
device_diff_area.id: device_diff_area,
|
||||
},
|
||||
)
|
||||
|
||||
entity_in_area = ent_reg.RegistryEntry(
|
||||
entity_id="light.in_area",
|
||||
unique_id="in-area-id",
|
||||
platform="test",
|
||||
device_id=device_in_area.id,
|
||||
)
|
||||
entity_no_area = ent_reg.RegistryEntry(
|
||||
entity_id="light.no_area",
|
||||
unique_id="no-area-id",
|
||||
platform="test",
|
||||
device_id=device_no_area.id,
|
||||
)
|
||||
entity_diff_area = ent_reg.RegistryEntry(
|
||||
entity_id="light.diff_area",
|
||||
unique_id="diff-area-id",
|
||||
platform="test",
|
||||
device_id=device_diff_area.id,
|
||||
)
|
||||
mock_registry(
|
||||
hass,
|
||||
{
|
||||
entity_in_area.entity_id: entity_in_area,
|
||||
entity_no_area.entity_id: entity_no_area,
|
||||
entity_diff_area.entity_id: entity_diff_area,
|
||||
},
|
||||
)
|
||||
|
||||
call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"})
|
||||
|
||||
assert {"light.in_area"} == await service.async_extract_entity_ids(hass, call)
|
||||
|
@ -678,3 +683,86 @@ async def test_domain_control_no_user(hass, mock_entities):
|
|||
)
|
||||
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
async def test_extract_from_service_available_device(hass):
|
||||
"""Test the extraction of entity from service and device is available."""
|
||||
entities = [
|
||||
MockEntity(name="test_1", entity_id="test_domain.test_1"),
|
||||
MockEntity(name="test_2", entity_id="test_domain.test_2", available=False),
|
||||
MockEntity(name="test_3", entity_id="test_domain.test_3"),
|
||||
MockEntity(name="test_4", entity_id="test_domain.test_4", available=False),
|
||||
]
|
||||
|
||||
call_1 = ha.ServiceCall("test", "service", data={"entity_id": ENTITY_MATCH_ALL})
|
||||
|
||||
assert ["test_domain.test_1", "test_domain.test_3"] == [
|
||||
ent.entity_id
|
||||
for ent in (await service.async_extract_entities(hass, entities, call_1))
|
||||
]
|
||||
|
||||
call_2 = ha.ServiceCall(
|
||||
"test",
|
||||
"service",
|
||||
data={"entity_id": ["test_domain.test_3", "test_domain.test_4"]},
|
||||
)
|
||||
|
||||
assert ["test_domain.test_3"] == [
|
||||
ent.entity_id
|
||||
for ent in (await service.async_extract_entities(hass, entities, call_2))
|
||||
]
|
||||
|
||||
|
||||
async def test_extract_from_service_empty_if_no_entity_id(hass):
|
||||
"""Test the extraction from service without specifying entity."""
|
||||
entities = [
|
||||
MockEntity(name="test_1", entity_id="test_domain.test_1"),
|
||||
MockEntity(name="test_2", entity_id="test_domain.test_2"),
|
||||
]
|
||||
call = ha.ServiceCall("test", "service")
|
||||
|
||||
assert [] == [
|
||||
ent.entity_id
|
||||
for ent in (await service.async_extract_entities(hass, entities, call))
|
||||
]
|
||||
|
||||
|
||||
async def test_extract_from_service_filter_out_non_existing_entities(hass):
|
||||
"""Test the extraction of non existing entities from service."""
|
||||
entities = [
|
||||
MockEntity(name="test_1", entity_id="test_domain.test_1"),
|
||||
MockEntity(name="test_2", entity_id="test_domain.test_2"),
|
||||
]
|
||||
|
||||
call = ha.ServiceCall(
|
||||
"test",
|
||||
"service",
|
||||
{"entity_id": ["test_domain.test_2", "test_domain.non_exist"]},
|
||||
)
|
||||
|
||||
assert ["test_domain.test_2"] == [
|
||||
ent.entity_id
|
||||
for ent in (await service.async_extract_entities(hass, entities, call))
|
||||
]
|
||||
|
||||
|
||||
async def test_extract_from_service_area_id(hass, area_mock):
|
||||
"""Test the extraction using area ID as reference."""
|
||||
entities = [
|
||||
MockEntity(name="in_area", entity_id="light.in_area"),
|
||||
MockEntity(name="no_area", entity_id="light.no_area"),
|
||||
MockEntity(name="diff_area", entity_id="light.diff_area"),
|
||||
]
|
||||
|
||||
call = ha.ServiceCall("light", "turn_on", {"area_id": "test-area"})
|
||||
extracted = await service.async_extract_entities(hass, entities, call)
|
||||
assert len(extracted) == 1
|
||||
assert extracted[0].entity_id == "light.in_area"
|
||||
|
||||
call = ha.ServiceCall("light", "turn_on", {"area_id": ["test-area", "diff-area"]})
|
||||
extracted = await service.async_extract_entities(hass, entities, call)
|
||||
assert len(extracted) == 2
|
||||
assert sorted(ent.entity_id for ent in extracted) == [
|
||||
"light.diff_area",
|
||||
"light.in_area",
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue