Rename wake_word.async_default_engine to wake_word.async_default_entity (#100855)

* Rename wake_word.async_default_engine to wake_word.async_default_entity

* tweak

* Some more rename

* Update tests
pull/100863/head
Erik Montnemery 2023-09-25 17:08:37 +02:00 committed by GitHub
parent 8ed0f05270
commit 803d24ad1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 26 additions and 26 deletions

View File

@ -413,8 +413,8 @@ class PipelineRun:
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
tts_engine: str = field(init=False)
tts_options: dict | None = field(init=False, default=None)
wake_word_engine: str = field(init=False)
wake_word_provider: wake_word.WakeWordDetectionEntity = field(init=False)
wake_word_entity_id: str = field(init=False)
wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False)
debug_recording_thread: Thread | None = None
"""Thread that records audio to debug_recording_dir"""
@ -476,24 +476,24 @@ class PipelineRun:
async def prepare_wake_word_detection(self) -> None:
"""Prepare wake-word-detection."""
engine = wake_word.async_default_engine(self.hass)
if engine is None:
entity_id = wake_word.async_default_entity(self.hass)
if entity_id is None:
raise WakeWordDetectionError(
code="wake-engine-missing",
message="No wake word engine",
)
wake_word_provider = wake_word.async_get_wake_word_detection_entity(
self.hass, engine
wake_word_entity = wake_word.async_get_wake_word_detection_entity(
self.hass, entity_id
)
if wake_word_provider is None:
if wake_word_entity is None:
raise WakeWordDetectionError(
code="wake-provider-missing",
message=f"No wake-word-detection provider for: {engine}",
message=f"No wake-word-detection provider for: {entity_id}",
)
self.wake_word_engine = engine
self.wake_word_provider = wake_word_provider
self.wake_word_entity_id = entity_id
self.wake_word_entity = wake_word_entity
async def wake_word_detection(
self,
@ -519,14 +519,14 @@ class PipelineRun:
PipelineEvent(
PipelineEventType.WAKE_WORD_START,
{
"engine": self.wake_word_engine,
"entity_id": self.wake_word_entity_id,
"metadata": metadata_dict,
},
)
)
if self.debug_recording_queue is not None:
self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_engine}")
self.debug_recording_queue.put_nowait(f"00_wake-{self.wake_word_entity_id}")
wake_word_settings = self.wake_word_settings or WakeWordSettings()
@ -548,7 +548,7 @@ class PipelineRun:
try:
# Detect wake word(s)
result = await self.wake_word_provider.async_process_audio_stream(
result = await self.wake_word_entity.async_process_audio_stream(
self._wake_word_audio_stream(
audio_stream=stream,
stt_audio_buffer=stt_audio_buffer,

View File

@ -19,7 +19,7 @@ from .const import DOMAIN
from .models import DetectionResult, WakeWord
__all__ = [
"async_default_engine",
"async_default_entity",
"async_get_wake_word_detection_entity",
"DetectionResult",
"DOMAIN",
@ -33,8 +33,8 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
@callback
def async_default_engine(hass: HomeAssistant) -> str | None:
"""Return the domain or entity id of the default engine."""
def async_default_entity(hass: HomeAssistant) -> str | None:
"""Return the entity id of the default engine."""
return next(iter(hass.states.async_entity_ids(DOMAIN)), None)

View File

@ -277,7 +277,7 @@
}),
dict({
'data': dict({
'engine': 'wake_word.test',
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,

View File

@ -185,7 +185,7 @@
# ---
# name: test_audio_pipeline_with_wake_word.1
dict({
'engine': 'wake_word.test',
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
@ -284,7 +284,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_no_timeout.1
dict({
'engine': 'wake_word.test',
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
@ -385,7 +385,7 @@
# ---
# name: test_audio_pipeline_with_wake_word_timeout.1
dict({
'engine': 'wake_word.test',
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': 16,
'channel': 1,

View File

@ -337,7 +337,7 @@ async def test_audio_pipeline_no_wake_word_engine(
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.wake_word.async_default_engine", return_value=None
"homeassistant.components.wake_word.async_default_entity", return_value=None
):
await client.send_json_auto_id(
{
@ -367,7 +367,7 @@ async def test_audio_pipeline_no_wake_word_entity(
client = await hass_ws_client(hass)
with patch(
"homeassistant.components.wake_word.async_default_engine",
"homeassistant.components.wake_word.async_default_entity",
return_value="wake_word.bad-entity-id",
), patch(
"homeassistant.components.wake_word.async_get_wake_word_detection_entity",

View File

@ -207,20 +207,20 @@ async def test_not_detected_entity(
async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None:
"""Test async_default_engine."""
"""Test async_default_entity."""
assert await async_setup_component(hass, wake_word.DOMAIN, {wake_word.DOMAIN: {}})
await hass.async_block_till_done()
assert wake_word.async_default_engine(hass) is None
assert wake_word.async_default_entity(hass) is None
async def test_default_engine_entity(
hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockProviderEntity
) -> None:
"""Test async_default_engine."""
"""Test async_default_entity."""
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
assert wake_word.async_default_engine(hass) == f"{wake_word.DOMAIN}.{TEST_DOMAIN}"
assert wake_word.async_default_entity(hass) == f"{wake_word.DOMAIN}.{TEST_DOMAIN}"
async def test_get_engine_entity(