Track entity sources (#37258)
Co-authored-by: David Mulcahey <david.mulcahey@me.com>pull/39041/head
parent
24a16ff8fe
commit
3dc79aa60a
|
@ -77,11 +77,6 @@ class DeconzDevice(DeconzBase, Entity):
|
|||
self.hass, self.gateway.signal_reachable, self.async_update_callback
|
||||
)
|
||||
)
|
||||
self.listeners.append(
|
||||
async_dispatcher_connect(
|
||||
self.hass, self.gateway.signal_remove_entity, self.async_remove_self
|
||||
)
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Disconnect device object when removed."""
|
||||
|
@ -91,15 +86,6 @@ class DeconzDevice(DeconzBase, Entity):
|
|||
for unsub_dispatcher in self.listeners:
|
||||
unsub_dispatcher()
|
||||
|
||||
async def async_remove_self(self, deconz_ids: list) -> None:
|
||||
"""Schedule removal of this entity.
|
||||
|
||||
Called by signal_remove_entity scheduled by async_added_to_hass.
|
||||
"""
|
||||
if self._device.deconz_id not in deconz_ids:
|
||||
return
|
||||
await self.async_remove()
|
||||
|
||||
@callback
|
||||
def async_update_callback(self, force_update=False, ignore_update=False):
|
||||
"""Update the device's state."""
|
||||
|
|
|
@ -164,15 +164,14 @@ class DeconzGateway:
|
|||
else:
|
||||
deconz_ids += [group.deconz_id for group in groups]
|
||||
|
||||
if deconz_ids:
|
||||
async_dispatcher_send(self.hass, self.signal_remove_entity, deconz_ids)
|
||||
|
||||
entity_registry = await self.hass.helpers.entity_registry.async_get_registry()
|
||||
|
||||
for entity_id, deconz_id in self.deconz_ids.items():
|
||||
if deconz_id in deconz_ids and entity_registry.async_is_registered(
|
||||
entity_id
|
||||
):
|
||||
# Removing an entity from the entity registry will also remove them
|
||||
# from Home Assistant
|
||||
entity_registry.async_remove(entity_id)
|
||||
|
||||
@property
|
||||
|
@ -197,11 +196,6 @@ class DeconzGateway:
|
|||
}
|
||||
return new_device[device_type]
|
||||
|
||||
@property
|
||||
def signal_remove_entity(self) -> str:
|
||||
"""Gateway specific event to signal removal of entity."""
|
||||
return f"deconz-remove-{self.bridgeid}"
|
||||
|
||||
@callback
|
||||
def async_add_device_callback(self, device_type, device) -> None:
|
||||
"""Handle event of new device creation in deCONZ."""
|
||||
|
|
|
@ -27,9 +27,16 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
|
|||
|
||||
# Get Dyson Devices from parent component
|
||||
device_ids = [device.unique_id for device in hass.data[DYSON_AIQ_DEVICES]]
|
||||
new_entities = []
|
||||
for device in hass.data[DYSON_DEVICES]:
|
||||
print(device.serial)
|
||||
if isinstance(device, DysonPureCool) and device.serial not in device_ids:
|
||||
hass.data[DYSON_AIQ_DEVICES].append(DysonAirSensor(device))
|
||||
new_entities.append(DysonAirSensor(device))
|
||||
|
||||
if not new_entities:
|
||||
return
|
||||
|
||||
hass.data[DYSON_AIQ_DEVICES].extend(new_entities)
|
||||
add_entities(hass.data[DYSON_AIQ_DEVICES])
|
||||
|
||||
|
||||
|
|
|
@ -41,18 +41,24 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
|
|||
|
||||
# Get Dyson Devices from parent component
|
||||
device_ids = [device.unique_id for device in hass.data[DYSON_SENSOR_DEVICES]]
|
||||
new_entities = []
|
||||
for device in hass.data[DYSON_DEVICES]:
|
||||
if isinstance(device, DysonPureCool):
|
||||
if f"{device.serial}-temperature" not in device_ids:
|
||||
devices.append(DysonTemperatureSensor(device, unit))
|
||||
new_entities.append(DysonTemperatureSensor(device, unit))
|
||||
if f"{device.serial}-humidity" not in device_ids:
|
||||
devices.append(DysonHumiditySensor(device))
|
||||
new_entities.append(DysonHumiditySensor(device))
|
||||
elif isinstance(device, DysonPureCoolLink):
|
||||
devices.append(DysonFilterLifeSensor(device))
|
||||
devices.append(DysonDustSensor(device))
|
||||
devices.append(DysonHumiditySensor(device))
|
||||
devices.append(DysonTemperatureSensor(device, unit))
|
||||
devices.append(DysonAirQualitySensor(device))
|
||||
new_entities.append(DysonFilterLifeSensor(device))
|
||||
new_entities.append(DysonDustSensor(device))
|
||||
new_entities.append(DysonHumiditySensor(device))
|
||||
new_entities.append(DysonTemperatureSensor(device, unit))
|
||||
new_entities.append(DysonAirQualitySensor(device))
|
||||
|
||||
if not new_entities:
|
||||
return
|
||||
|
||||
devices.extend(new_entities)
|
||||
add_entities(devices)
|
||||
|
||||
|
||||
|
|
|
@ -3,11 +3,12 @@ import asyncio
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth.permissions.const import POLICY_READ
|
||||
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
|
||||
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
|
||||
from homeassistant.const import EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL
|
||||
from homeassistant.core import DOMAIN as HASS_DOMAIN, callback
|
||||
from homeassistant.exceptions import HomeAssistantError, ServiceNotFound, Unauthorized
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers import config_validation as cv, entity
|
||||
from homeassistant.helpers.event import async_track_state_change_event
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||
|
@ -30,6 +31,7 @@ def async_register_commands(hass, async_reg):
|
|||
async_reg(hass, handle_render_template)
|
||||
async_reg(hass, handle_manifest_list)
|
||||
async_reg(hass, handle_manifest_get)
|
||||
async_reg(hass, handle_entity_source)
|
||||
|
||||
|
||||
def pong_message(iden):
|
||||
|
@ -263,3 +265,46 @@ def handle_render_template(hass, connection, msg):
|
|||
|
||||
connection.send_result(msg["id"])
|
||||
state_listener()
|
||||
|
||||
|
||||
@callback
|
||||
@decorators.websocket_command(
|
||||
{vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]}
|
||||
)
|
||||
def handle_entity_source(hass, connection, msg):
|
||||
"""Handle entity source command."""
|
||||
raw_sources = entity.entity_sources(hass)
|
||||
entity_perm = connection.user.permissions.check_entity
|
||||
|
||||
if "entity_id" not in msg:
|
||||
if connection.user.permissions.access_all_entities("read"):
|
||||
sources = raw_sources
|
||||
else:
|
||||
sources = {
|
||||
entity_id: source
|
||||
for entity_id, source in raw_sources.items()
|
||||
if entity_perm(entity_id, "read")
|
||||
}
|
||||
|
||||
connection.send_message(messages.result_message(msg["id"], sources))
|
||||
return
|
||||
|
||||
sources = {}
|
||||
|
||||
for entity_id in msg["entity_id"]:
|
||||
if not entity_perm(entity_id, "read"):
|
||||
raise Unauthorized(
|
||||
context=connection.context(msg),
|
||||
permission=POLICY_READ,
|
||||
perm_category=CAT_ENTITIES,
|
||||
)
|
||||
|
||||
source = raw_sources.get(entity_id)
|
||||
|
||||
if source is None:
|
||||
connection.send_error(msg["id"], ERR_NOT_FOUND, "Entity not found")
|
||||
return
|
||||
|
||||
sources[entity_id] = source
|
||||
|
||||
connection.send_result(msg["id"], sources)
|
||||
|
|
|
@ -274,7 +274,6 @@ SIGNAL_REMOVE = "remove"
|
|||
SIGNAL_SET_LEVEL = "set_level"
|
||||
SIGNAL_STATE_ATTR = "update_state_attribute"
|
||||
SIGNAL_UPDATE_DEVICE = "{}_zha_update_device"
|
||||
SIGNAL_REMOVE_GROUP = "remove_group"
|
||||
SIGNAL_GROUP_ENTITY_REMOVED = "group_entity_removed"
|
||||
SIGNAL_GROUP_MEMBERSHIP_CHANGE = "group_membership_change"
|
||||
|
||||
|
|
|
@ -57,7 +57,6 @@ from .const import (
|
|||
SIGNAL_ADD_ENTITIES,
|
||||
SIGNAL_GROUP_MEMBERSHIP_CHANGE,
|
||||
SIGNAL_REMOVE,
|
||||
SIGNAL_REMOVE_GROUP,
|
||||
UNKNOWN_MANUFACTURER,
|
||||
UNKNOWN_MODEL,
|
||||
ZHA_GW_MSG,
|
||||
|
@ -298,13 +297,10 @@ class ZHAGateway:
|
|||
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED)
|
||||
|
||||
def group_removed(self, zigpy_group: ZigpyGroupType) -> None:
|
||||
"""Handle zigpy group added event."""
|
||||
"""Handle zigpy group removed event."""
|
||||
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED)
|
||||
zha_group = self._groups.pop(zigpy_group.group_id, None)
|
||||
zha_group.info("group_removed")
|
||||
async_dispatcher_send(
|
||||
self._hass, f"{SIGNAL_REMOVE_GROUP}_0x{zigpy_group.group_id:04x}"
|
||||
)
|
||||
self._cleanup_group_entity_registry_entries(zigpy_group)
|
||||
|
||||
def _send_group_gateway_message(
|
||||
|
@ -619,7 +615,7 @@ class ZHAGateway:
|
|||
if not group:
|
||||
_LOGGER.debug("Group: %s:0x%04x could not be found", group.name, group_id)
|
||||
return
|
||||
if group and group.members:
|
||||
if group.members:
|
||||
tasks = []
|
||||
for member in group.members:
|
||||
tasks.append(member.async_remove_from_group())
|
||||
|
|
|
@ -24,7 +24,6 @@ from .core.const import (
|
|||
SIGNAL_GROUP_ENTITY_REMOVED,
|
||||
SIGNAL_GROUP_MEMBERSHIP_CHANGE,
|
||||
SIGNAL_REMOVE,
|
||||
SIGNAL_REMOVE_GROUP,
|
||||
)
|
||||
from .core.helpers import LogMixin
|
||||
from .core.typing import CALLABLE_T, ChannelType, ZhaDeviceType
|
||||
|
@ -217,32 +216,35 @@ class ZhaGroupEntity(BaseZhaEntity):
|
|||
"""Initialize a light group."""
|
||||
super().__init__(unique_id, zha_device, **kwargs)
|
||||
self._available = False
|
||||
self._name = (
|
||||
f"{zha_device.gateway.groups.get(group_id).name}_zha_group_0x{group_id:04x}"
|
||||
)
|
||||
self._group = zha_device.gateway.groups.get(group_id)
|
||||
self._name = f"{self._group.name}_zha_group_0x{group_id:04x}"
|
||||
self._group_id: int = group_id
|
||||
self._entity_ids: List[str] = entity_ids
|
||||
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
|
||||
self._handled_group_membership = False
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
"""Return entity availability."""
|
||||
return self._available
|
||||
|
||||
async def _handle_group_membership_changed(self):
|
||||
"""Handle group membership changed."""
|
||||
# Make sure we don't call remove twice as members are removed
|
||||
if self._handled_group_membership:
|
||||
return
|
||||
|
||||
self._handled_group_membership = True
|
||||
await self.async_remove()
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Register callbacks."""
|
||||
await super().async_added_to_hass()
|
||||
self.async_accept_signal(
|
||||
None,
|
||||
f"{SIGNAL_REMOVE_GROUP}_0x{self._group_id:04x}",
|
||||
self.async_remove,
|
||||
signal_override=True,
|
||||
)
|
||||
|
||||
self.async_accept_signal(
|
||||
None,
|
||||
f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_0x{self._group_id:04x}",
|
||||
self.async_remove,
|
||||
self._handle_group_membership_changed,
|
||||
signal_override=True,
|
||||
)
|
||||
|
||||
|
|
|
@ -54,6 +54,10 @@ class Unauthorized(HomeAssistantError):
|
|||
"""Unauthorized error."""
|
||||
super().__init__(self.__class__.__name__)
|
||||
self.context = context
|
||||
|
||||
if user_id is None and context is not None:
|
||||
user_id = context.user_id
|
||||
|
||||
self.user_id = user_id
|
||||
self.entity_id = entity_id
|
||||
self.config_entry_id = config_entry_id
|
||||
|
|
|
@ -25,15 +25,26 @@ from homeassistant.const import (
|
|||
TEMP_FAHRENHEIT,
|
||||
)
|
||||
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import NoEntitySpecifiedError
|
||||
from homeassistant.exceptions import HomeAssistantError, NoEntitySpecifiedError
|
||||
from homeassistant.helpers.entity_platform import EntityPlatform
|
||||
from homeassistant.helpers.entity_registry import RegistryEntry
|
||||
from homeassistant.helpers.event import Event, async_track_entity_registry_updated_event
|
||||
from homeassistant.helpers.typing import StateType
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
SLOW_UPDATE_WARNING = 10
|
||||
DATA_ENTITY_SOURCE = "entity_info"
|
||||
SOURCE_CONFIG_ENTRY = "config_entry"
|
||||
SOURCE_PLATFORM_CONFIG = "platform_config"
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def entity_sources(hass: HomeAssistant) -> Dict[str, Dict[str, str]]:
|
||||
"""Get the entity sources."""
|
||||
return hass.data.get(DATA_ENTITY_SOURCE, {})
|
||||
|
||||
|
||||
def generate_entity_id(
|
||||
|
@ -109,6 +120,9 @@ class Entity(ABC):
|
|||
_context: Optional[Context] = None
|
||||
_context_set: Optional[datetime] = None
|
||||
|
||||
# If entity is added to an entity platform
|
||||
_added = False
|
||||
|
||||
@property
|
||||
def should_poll(self) -> bool:
|
||||
"""Return True if entity has to be polled for state.
|
||||
|
@ -477,10 +491,49 @@ class Entity(ABC):
|
|||
To be extended by integrations.
|
||||
"""
|
||||
|
||||
@callback
|
||||
def add_to_platform_start(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
platform: EntityPlatform,
|
||||
parallel_updates: Optional[asyncio.Semaphore],
|
||||
) -> None:
|
||||
"""Start adding an entity to a platform."""
|
||||
if self._added:
|
||||
raise HomeAssistantError(
|
||||
f"Entity {self.entity_id} cannot be added a second time to an entity platform"
|
||||
)
|
||||
|
||||
self.hass = hass
|
||||
self.platform = platform
|
||||
self.parallel_updates = parallel_updates
|
||||
self._added = True
|
||||
|
||||
@callback
|
||||
def add_to_platform_abort(self) -> None:
|
||||
"""Abort adding an entity to a platform."""
|
||||
self.hass = None
|
||||
self.platform = None
|
||||
self.parallel_updates = None
|
||||
self._added = False
|
||||
|
||||
async def add_to_platform_finish(self) -> None:
|
||||
"""Finish adding an entity to a platform."""
|
||||
await self.async_internal_added_to_hass()
|
||||
await self.async_added_to_hass()
|
||||
self.async_write_ha_state()
|
||||
|
||||
async def async_remove(self) -> None:
|
||||
"""Remove entity from Home Assistant."""
|
||||
assert self.hass is not None
|
||||
|
||||
if self.platform and not self._added:
|
||||
raise HomeAssistantError(
|
||||
f"Entity {self.entity_id} async_remove called twice"
|
||||
)
|
||||
|
||||
self._added = False
|
||||
|
||||
if self._on_remove is not None:
|
||||
while self._on_remove:
|
||||
self._on_remove.pop()()
|
||||
|
@ -507,8 +560,25 @@ class Entity(ABC):
|
|||
|
||||
Not to be extended by integrations.
|
||||
"""
|
||||
assert self.hass is not None
|
||||
|
||||
if self.platform:
|
||||
info = {"domain": self.platform.platform_name}
|
||||
|
||||
if self.platform.config_entry:
|
||||
info["source"] = SOURCE_CONFIG_ENTRY
|
||||
info["config_entry"] = self.platform.config_entry.entry_id
|
||||
else:
|
||||
info["source"] = SOURCE_PLATFORM_CONFIG
|
||||
|
||||
self.hass.data.setdefault(DATA_ENTITY_SOURCE, {})[self.entity_id] = info
|
||||
|
||||
if self.registry_entry is not None:
|
||||
assert self.hass is not None
|
||||
# This is an assert as it should never happen, but helps in tests
|
||||
assert (
|
||||
not self.registry_entry.disabled_by
|
||||
), f"Entity {self.entity_id} is being added while it's disabled"
|
||||
|
||||
self.async_on_remove(
|
||||
async_track_entity_registry_updated_event(
|
||||
self.hass, self.entity_id, self._async_registry_updated
|
||||
|
@ -520,6 +590,9 @@ class Entity(ABC):
|
|||
|
||||
Not to be extended by integrations.
|
||||
"""
|
||||
if self.platform:
|
||||
assert self.hass is not None
|
||||
self.hass.data[DATA_ENTITY_SOURCE].pop(self.entity_id)
|
||||
|
||||
async def _async_registry_updated(self, event: Event) -> None:
|
||||
"""Handle entity registry update."""
|
||||
|
|
|
@ -6,7 +6,7 @@ from logging import Logger
|
|||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Iterable, List, Optional
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.const import DEVICE_DEFAULT_NAME
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
|
@ -60,7 +60,7 @@ class EntityPlatform:
|
|||
self.platform = platform
|
||||
self.scan_interval = scan_interval
|
||||
self.entity_namespace = entity_namespace
|
||||
self.config_entry: Optional[ConfigEntry] = None
|
||||
self.config_entry: Optional[config_entries.ConfigEntry] = None
|
||||
self.entities: Dict[str, Entity] = {} # pylint: disable=used-before-assignment
|
||||
self._tasks: List[asyncio.Future] = []
|
||||
# Method to cancel the state change listener
|
||||
|
@ -149,7 +149,7 @@ class EntityPlatform:
|
|||
|
||||
await self._async_setup_platform(async_create_setup_task)
|
||||
|
||||
async def async_setup_entry(self, config_entry: ConfigEntry) -> bool:
|
||||
async def async_setup_entry(self, config_entry: config_entries.ConfigEntry) -> bool:
|
||||
"""Set up the platform from a config entry."""
|
||||
# Store it so that we can save config entry ID in entity registry
|
||||
self.config_entry = config_entry
|
||||
|
@ -332,10 +332,10 @@ class EntityPlatform:
|
|||
if entity is None:
|
||||
raise ValueError("Entity cannot be None")
|
||||
|
||||
entity.hass = self.hass
|
||||
entity.platform = self
|
||||
entity.parallel_updates = self._get_parallel_updates_semaphore(
|
||||
hasattr(entity, "async_update")
|
||||
entity.add_to_platform_start(
|
||||
self.hass,
|
||||
self,
|
||||
self._get_parallel_updates_semaphore(hasattr(entity, "async_update")),
|
||||
)
|
||||
|
||||
# Update properties before we generate the entity_id
|
||||
|
@ -344,8 +344,7 @@ class EntityPlatform:
|
|||
await entity.async_device_update(warning=False)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.logger.exception("%s: Error on device update!", self.platform_name)
|
||||
entity.hass = None
|
||||
entity.platform = None
|
||||
entity.add_to_platform_abort()
|
||||
return
|
||||
|
||||
requested_entity_id = None
|
||||
|
@ -423,8 +422,7 @@ class EntityPlatform:
|
|||
or entity.name
|
||||
or f'"{self.platform_name} {entity.unique_id}"',
|
||||
)
|
||||
entity.hass = None
|
||||
entity.platform = None
|
||||
entity.add_to_platform_abort()
|
||||
return
|
||||
|
||||
# We won't generate an entity ID if the platform has already set one
|
||||
|
@ -450,8 +448,7 @@ class EntityPlatform:
|
|||
|
||||
# Make sure it is valid in case an entity set the value themselves
|
||||
if not valid_entity_id(entity.entity_id):
|
||||
entity.hass = None
|
||||
entity.platform = None
|
||||
entity.add_to_platform_abort()
|
||||
raise HomeAssistantError(f"Invalid entity id: {entity.entity_id}")
|
||||
|
||||
already_exists = entity.entity_id in self.entities
|
||||
|
@ -472,18 +469,14 @@ class EntityPlatform:
|
|||
else:
|
||||
msg = f"Entity id already exists - ignoring: {entity.entity_id}"
|
||||
self.logger.error(msg)
|
||||
entity.hass = None
|
||||
entity.platform = None
|
||||
entity.add_to_platform_abort()
|
||||
return
|
||||
|
||||
entity_id = entity.entity_id
|
||||
self.entities[entity_id] = entity
|
||||
entity.async_on_remove(lambda: self.entities.pop(entity_id))
|
||||
|
||||
await entity.async_internal_added_to_hass()
|
||||
await entity.async_added_to_hass()
|
||||
|
||||
entity.async_write_ha_state()
|
||||
await entity.add_to_platform_finish()
|
||||
|
||||
async def async_reset(self) -> None:
|
||||
"""Remove all entities and reset data.
|
||||
|
|
|
@ -1,85 +1,55 @@
|
|||
"""The tests for the Switch component."""
|
||||
# pylint: disable=protected-access
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from homeassistant import core
|
||||
from homeassistant.components import switch
|
||||
from homeassistant.const import CONF_PLATFORM
|
||||
from homeassistant.setup import async_setup_component, setup_component
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import get_test_home_assistant, mock_entity_platform
|
||||
from tests.components.switch import common
|
||||
|
||||
|
||||
class TestSwitch(unittest.TestCase):
|
||||
"""Test the switch module."""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def setUp(self):
|
||||
"""Set up things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
platform = getattr(self.hass.components, "test.switch")
|
||||
platform.init()
|
||||
# Switch 1 is ON, switch 2 is OFF
|
||||
self.switch_1, self.switch_2, self.switch_3 = platform.ENTITIES
|
||||
self.addCleanup(self.hass.stop)
|
||||
|
||||
def test_methods(self):
|
||||
"""Test is_on, turn_on, turn_off methods."""
|
||||
assert setup_component(
|
||||
self.hass, switch.DOMAIN, {switch.DOMAIN: {CONF_PLATFORM: "test"}}
|
||||
)
|
||||
self.hass.block_till_done()
|
||||
assert switch.is_on(self.hass, self.switch_1.entity_id)
|
||||
assert not switch.is_on(self.hass, self.switch_2.entity_id)
|
||||
assert not switch.is_on(self.hass, self.switch_3.entity_id)
|
||||
|
||||
common.turn_off(self.hass, self.switch_1.entity_id)
|
||||
common.turn_on(self.hass, self.switch_2.entity_id)
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert not switch.is_on(self.hass, self.switch_1.entity_id)
|
||||
assert switch.is_on(self.hass, self.switch_2.entity_id)
|
||||
|
||||
# Turn all off
|
||||
common.turn_off(self.hass)
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert not switch.is_on(self.hass, self.switch_1.entity_id)
|
||||
assert not switch.is_on(self.hass, self.switch_2.entity_id)
|
||||
assert not switch.is_on(self.hass, self.switch_3.entity_id)
|
||||
|
||||
# Turn all on
|
||||
common.turn_on(self.hass)
|
||||
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert switch.is_on(self.hass, self.switch_1.entity_id)
|
||||
assert switch.is_on(self.hass, self.switch_2.entity_id)
|
||||
assert switch.is_on(self.hass, self.switch_3.entity_id)
|
||||
|
||||
def test_setup_two_platforms(self):
|
||||
"""Test with bad configuration."""
|
||||
# Test if switch component returns 0 switches
|
||||
test_platform = getattr(self.hass.components, "test.switch")
|
||||
test_platform.init(True)
|
||||
|
||||
mock_entity_platform(self.hass, "switch.test2", test_platform)
|
||||
test_platform.init(False)
|
||||
|
||||
assert setup_component(
|
||||
self.hass,
|
||||
switch.DOMAIN,
|
||||
{
|
||||
switch.DOMAIN: {CONF_PLATFORM: "test"},
|
||||
f"{switch.DOMAIN} 2": {CONF_PLATFORM: "test2"},
|
||||
},
|
||||
)
|
||||
@pytest.fixture(autouse=True)
|
||||
def entities(hass):
|
||||
"""Initialize the test switch."""
|
||||
platform = getattr(hass.components, "test.switch")
|
||||
platform.init()
|
||||
yield platform.ENTITIES
|
||||
|
||||
|
||||
async def test_switch_context(hass, hass_admin_user):
|
||||
async def test_methods(hass, entities):
|
||||
"""Test is_on, turn_on, turn_off methods."""
|
||||
switch_1, switch_2, switch_3 = entities
|
||||
assert await async_setup_component(
|
||||
hass, switch.DOMAIN, {switch.DOMAIN: {CONF_PLATFORM: "test"}}
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert switch.is_on(hass, switch_1.entity_id)
|
||||
assert not switch.is_on(hass, switch_2.entity_id)
|
||||
assert not switch.is_on(hass, switch_3.entity_id)
|
||||
|
||||
await common.async_turn_off(hass, switch_1.entity_id)
|
||||
await common.async_turn_on(hass, switch_2.entity_id)
|
||||
|
||||
assert not switch.is_on(hass, switch_1.entity_id)
|
||||
assert switch.is_on(hass, switch_2.entity_id)
|
||||
|
||||
# Turn all off
|
||||
await common.async_turn_off(hass)
|
||||
|
||||
assert not switch.is_on(hass, switch_1.entity_id)
|
||||
assert not switch.is_on(hass, switch_2.entity_id)
|
||||
assert not switch.is_on(hass, switch_3.entity_id)
|
||||
|
||||
# Turn all on
|
||||
await common.async_turn_on(hass)
|
||||
|
||||
assert switch.is_on(hass, switch_1.entity_id)
|
||||
assert switch.is_on(hass, switch_2.entity_id)
|
||||
assert switch.is_on(hass, switch_3.entity_id)
|
||||
|
||||
|
||||
async def test_switch_context(hass, entities, hass_admin_user):
|
||||
"""Test that switch context works."""
|
||||
assert await async_setup_component(hass, "switch", {"switch": {"platform": "test"}})
|
||||
|
||||
|
|
|
@ -10,10 +10,11 @@ from homeassistant.components.websocket_api.auth import (
|
|||
from homeassistant.components.websocket_api.const import URL
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import entity
|
||||
from homeassistant.loader import async_get_integration
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import async_mock_service
|
||||
from tests.common import MockEntity, MockEntityPlatform, async_mock_service
|
||||
|
||||
|
||||
async def test_call_service(hass, websocket_client):
|
||||
|
@ -519,3 +520,116 @@ async def test_manifest_get(hass, websocket_client):
|
|||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "not_found"
|
||||
|
||||
|
||||
async def test_entity_source_admin(hass, websocket_client, hass_admin_user):
|
||||
"""Check that we fetch sources correctly."""
|
||||
platform = MockEntityPlatform(hass)
|
||||
|
||||
await platform.async_add_entities(
|
||||
[MockEntity(name="Entity 1"), MockEntity(name="Entity 2")]
|
||||
)
|
||||
|
||||
# Fetch all
|
||||
await websocket_client.send_json({"id": 6, "type": "entity/source"})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 6
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"test_domain.entity_1": {
|
||||
"source": entity.SOURCE_PLATFORM_CONFIG,
|
||||
"domain": "test_platform",
|
||||
},
|
||||
"test_domain.entity_2": {
|
||||
"source": entity.SOURCE_PLATFORM_CONFIG,
|
||||
"domain": "test_platform",
|
||||
},
|
||||
}
|
||||
|
||||
# Fetch one
|
||||
await websocket_client.send_json(
|
||||
{"id": 7, "type": "entity/source", "entity_id": ["test_domain.entity_2"]}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 7
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"test_domain.entity_2": {
|
||||
"source": entity.SOURCE_PLATFORM_CONFIG,
|
||||
"domain": "test_platform",
|
||||
},
|
||||
}
|
||||
|
||||
# Fetch two
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
"id": 8,
|
||||
"type": "entity/source",
|
||||
"entity_id": ["test_domain.entity_2", "test_domain.entity_1"],
|
||||
}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 8
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"test_domain.entity_1": {
|
||||
"source": entity.SOURCE_PLATFORM_CONFIG,
|
||||
"domain": "test_platform",
|
||||
},
|
||||
"test_domain.entity_2": {
|
||||
"source": entity.SOURCE_PLATFORM_CONFIG,
|
||||
"domain": "test_platform",
|
||||
},
|
||||
}
|
||||
|
||||
# Fetch non existing
|
||||
await websocket_client.send_json(
|
||||
{
|
||||
"id": 9,
|
||||
"type": "entity/source",
|
||||
"entity_id": ["test_domain.entity_2", "test_domain.non_existing"],
|
||||
}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 9
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == const.ERR_NOT_FOUND
|
||||
|
||||
# Mock policy
|
||||
hass_admin_user.groups = []
|
||||
hass_admin_user.mock_policy(
|
||||
{"entities": {"entity_ids": {"test_domain.entity_2": True}}}
|
||||
)
|
||||
|
||||
# Fetch all
|
||||
await websocket_client.send_json({"id": 10, "type": "entity/source"})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 10
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"test_domain.entity_2": {
|
||||
"source": entity.SOURCE_PLATFORM_CONFIG,
|
||||
"domain": "test_platform",
|
||||
},
|
||||
}
|
||||
|
||||
# Fetch unauthorized
|
||||
await websocket_client.send_json(
|
||||
{"id": 11, "type": "entity/source", "entity_id": ["test_domain.entity_1"]}
|
||||
)
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg["id"] == 11
|
||||
assert msg["type"] == const.TYPE_RESULT
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == const.ERR_UNAUTHORIZED
|
||||
|
|
|
@ -11,7 +11,13 @@ from homeassistant.core import Context
|
|||
from homeassistant.helpers import entity, entity_registry
|
||||
|
||||
from tests.async_mock import MagicMock, PropertyMock, patch
|
||||
from tests.common import get_test_home_assistant, mock_registry
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockEntity,
|
||||
MockEntityPlatform,
|
||||
get_test_home_assistant,
|
||||
mock_registry,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_entity_id_requires_hass_or_ids():
|
||||
|
@ -603,7 +609,7 @@ async def test_disabled_in_entity_registry(hass):
|
|||
entity_id="hello.world",
|
||||
unique_id="test-unique-id",
|
||||
platform="test-platform",
|
||||
disabled_by="user",
|
||||
disabled_by=None,
|
||||
)
|
||||
registry = mock_registry(hass, {"hello.world": entry})
|
||||
|
||||
|
@ -611,23 +617,24 @@ async def test_disabled_in_entity_registry(hass):
|
|||
ent.hass = hass
|
||||
ent.entity_id = "hello.world"
|
||||
ent.registry_entry = entry
|
||||
ent.platform = MagicMock(platform_name="test-platform")
|
||||
assert ent.enabled is True
|
||||
|
||||
await ent.async_internal_added_to_hass()
|
||||
ent.async_write_ha_state()
|
||||
assert hass.states.get("hello.world") is None
|
||||
ent.add_to_platform_start(hass, MagicMock(platform_name="test-platform"), None)
|
||||
await ent.add_to_platform_finish()
|
||||
assert hass.states.get("hello.world") is not None
|
||||
|
||||
entry2 = registry.async_update_entity("hello.world", disabled_by=None)
|
||||
entry2 = registry.async_update_entity("hello.world", disabled_by="user")
|
||||
await hass.async_block_till_done()
|
||||
assert entry2 != entry
|
||||
assert ent.registry_entry == entry2
|
||||
assert ent.enabled is True
|
||||
assert ent.enabled is False
|
||||
assert hass.states.get("hello.world") is None
|
||||
|
||||
entry3 = registry.async_update_entity("hello.world", disabled_by="user")
|
||||
entry3 = registry.async_update_entity("hello.world", disabled_by=None)
|
||||
await hass.async_block_till_done()
|
||||
assert entry3 != entry2
|
||||
assert ent.registry_entry == entry3
|
||||
assert ent.enabled is False
|
||||
# Entry is no longer updated, entity is no longer tracking changes
|
||||
assert ent.registry_entry == entry2
|
||||
|
||||
|
||||
async def test_capability_attrs(hass):
|
||||
|
@ -690,3 +697,31 @@ async def test_warn_slow_write_state_custom_component(hass, caplog):
|
|||
"(<class 'custom_components.bla.sensor.test_warn_slow_write_state_custom_component.<locals>.CustomComponentEntity'>) "
|
||||
"took 10.000 seconds. Please report it to the custom component author."
|
||||
) in caplog.text
|
||||
|
||||
|
||||
async def test_setup_source(hass):
|
||||
"""Check that we register sources correctly."""
|
||||
platform = MockEntityPlatform(hass)
|
||||
|
||||
entity_platform = MockEntity(name="Platform Config Source")
|
||||
await platform.async_add_entities([entity_platform])
|
||||
|
||||
platform.config_entry = MockConfigEntry()
|
||||
entity_entry = MockEntity(name="Config Entry Source")
|
||||
await platform.async_add_entities([entity_entry])
|
||||
|
||||
assert entity.entity_sources(hass) == {
|
||||
"test_domain.platform_config_source": {
|
||||
"source": entity.SOURCE_PLATFORM_CONFIG,
|
||||
"domain": "test_platform",
|
||||
},
|
||||
"test_domain.config_entry_source": {
|
||||
"source": entity.SOURCE_CONFIG_ENTRY,
|
||||
"config_entry": platform.config_entry.entry_id,
|
||||
"domain": "test_platform",
|
||||
},
|
||||
}
|
||||
|
||||
await platform.async_reset()
|
||||
|
||||
assert entity.entity_sources(hass) == {}
|
||||
|
|
|
@ -29,4 +29,5 @@ async def async_setup_platform(
|
|||
hass, config, async_add_entities_callback, discovery_info=None
|
||||
):
|
||||
"""Return mock entities."""
|
||||
print("YOOO")
|
||||
async_add_entities_callback(ENTITIES)
|
||||
|
|
Loading…
Reference in New Issue