Allow entity names for STT entities (#91932)

* Allow entity names for STT entities

* Fix tests
pull/91940/head
Paulus Schoutsen 2023-04-23 23:06:34 -04:00 committed by GitHub
parent fba7c6cacd
commit a203149133
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 44 additions and 50 deletions

View File

@ -304,7 +304,10 @@ class PipelineRun:
if self.stt_provider is None:
raise RuntimeError("Speech to text was not prepared")
if isinstance(self.stt_provider, stt.Provider):
engine = self.stt_provider.name
else:
engine = self.stt_provider.entity_id
self.process_event(
PipelineEvent(

View File

@ -44,6 +44,8 @@ async def async_setup_entry(
class DemoProviderEntity(SpeechToTextEntity):
"""Demo speech API provider entity."""
_attr_name = "Demo STT"
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""

View File

@ -128,16 +128,6 @@ class SpeechToTextEntity(RestoreEntity):
_attr_should_poll = False
__last_processed: str | None = None
@property
@final
def name(self) -> str:
"""Return the name of the provider entity."""
# Only one entity is allowed per platform for now.
if self.platform is None:
raise RuntimeError("Entity is not added to hass yet.")
return self.platform.platform_name
@property
@final
def state(self) -> str | None:
@ -249,11 +239,7 @@ class SpeechToTextView(HomeAssistantView):
hass: HomeAssistant = request.app["hass"]
provider_entity: SpeechToTextEntity | None = None
if (
not (
provider_entity := async_get_speech_to_text_entity(
hass, f"{DOMAIN}.{provider}"
)
)
not (provider_entity := async_get_speech_to_text_entity(hass, provider))
and provider not in self.providers
):
raise HTTPNotFound()
@ -292,11 +278,7 @@ class SpeechToTextView(HomeAssistantView):
"""Return provider specific audio information."""
hass: HomeAssistant = request.app["hass"]
if (
not (
provider_entity := async_get_speech_to_text_entity(
hass, f"{DOMAIN}.{provider}"
)
)
not (provider_entity := async_get_speech_to_text_entity(hass, provider))
and provider not in self.providers
):
raise HTTPNotFound()

View File

@ -89,6 +89,8 @@ class MockSttProvider(BaseProvider, stt.Provider):
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
"""Mock provider entity."""
_attr_name = "Mock STT"
class MockTTSProvider(tts.Provider):
"""Mock TTS provider."""

View File

@ -97,7 +97,7 @@
}),
dict({
'data': dict({
'engine': 'test',
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,

View File

@ -1,10 +1,12 @@
"""The tests for the demo stt component."""
from http import HTTPStatus
from unittest.mock import patch
import pytest
from homeassistant.components import stt
from homeassistant.components.demo import DOMAIN as DEMO_DOMAIN
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
@ -24,6 +26,10 @@ async def setup_config_entry(hass: HomeAssistant) -> None:
"""Set up demo component from config entry."""
config_entry = MockConfigEntry(domain=DEMO_DOMAIN)
config_entry.add_to_hass(hass)
with patch(
"homeassistant.components.demo.COMPONENTS_WITH_CONFIG_ENTRY_DEMO_PLATFORM",
[Platform.STT],
):
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
@ -103,7 +109,7 @@ async def test_config_entry_demo_speech(
client = await hass_client()
response = await client.post(
"/api/stt/demo",
"/api/stt/stt.demo_stt",
headers={
"X-Speech-Content": (
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=2;"

View File

@ -93,10 +93,15 @@ class BaseProvider:
class MockProvider(BaseProvider, Provider):
"""Mock provider."""
url_path = TEST_DOMAIN
class MockProviderEntity(BaseProvider, SpeechToTextEntity):
"""Mock provider entity."""
url_path = "stt.test"
_attr_name = "test"
@pytest.fixture
def mock_provider() -> MockProvider:
@ -128,15 +133,19 @@ async def setup_fixture(
hass: HomeAssistant,
tmp_path: Path,
request: pytest.FixtureRequest,
) -> None:
) -> MockProvider | MockProviderEntity:
"""Set up the test environment."""
if request.param == "mock_setup":
await mock_setup(hass, tmp_path, MockProvider())
provider = MockProvider()
await mock_setup(hass, tmp_path, provider)
elif request.param == "mock_config_entry_setup":
await mock_config_entry_setup(hass, tmp_path, MockProviderEntity())
provider = MockProviderEntity()
await mock_config_entry_setup(hass, tmp_path, provider)
else:
raise RuntimeError("Invalid setup fixture")
return provider
async def mock_setup(
hass: HomeAssistant,
@ -206,11 +215,11 @@ async def mock_config_entry_setup(
async def test_get_provider_info(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: str,
setup: MockProvider | MockProviderEntity,
) -> None:
"""Test engine that doesn't exist."""
client = await hass_client()
response = await client.get(f"/api/stt/{TEST_DOMAIN}")
response = await client.get(f"/api/stt/{setup.url_path}")
assert response.status == HTTPStatus.OK
assert await response.json() == {
"languages": ["de", "de-CH", "en-US"],
@ -228,7 +237,7 @@ async def test_get_provider_info(
async def test_non_existing_provider(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: str,
setup: MockProvider | MockProviderEntity,
) -> None:
"""Test streaming to engine that doesn't exist."""
client = await hass_client()
@ -255,14 +264,14 @@ async def test_non_existing_provider(
async def test_stream_audio(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: str,
setup: MockProvider | MockProviderEntity,
) -> None:
"""Test streaming audio and getting response."""
client = await hass_client()
# Language en is matched with en-US
response = await client.post(
f"/api/stt/{TEST_DOMAIN}",
f"/api/stt/{setup.url_path}",
headers={
"X-Speech-Content": (
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
@ -318,7 +327,7 @@ async def test_metadata_errors(
header: str | None,
status: int,
error: str,
setup: str,
setup: MockProvider | MockProviderEntity,
) -> None:
"""Test metadata errors."""
client = await hass_client()
@ -326,7 +335,7 @@ async def test_metadata_errors(
if header:
headers["X-Speech-Content"] = header
response = await client.post(f"/api/stt/{TEST_DOMAIN}", headers=headers)
response = await client.post(f"/api/stt/{setup.url_path}", headers=headers)
assert response.status == status
assert await response.text() == error
@ -351,16 +360,6 @@ async def test_config_entry_unload(
assert config_entry.state == ConfigEntryState.NOT_LOADED
def test_entity_name_raises_before_addition(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
) -> None:
"""Test entity name raises before addition to Home Assistant."""
with pytest.raises(RuntimeError):
mock_provider_entity.name # pylint: disable=pointless-statement
async def test_restore_state(
hass: HomeAssistant,
tmp_path: Path,
@ -388,7 +387,7 @@ async def test_restore_state(
async def test_ws_list_engines(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
setup: str,
setup: MockProvider | MockProviderEntity,
engine_id: str,
) -> None:
"""Test listing speech to text engines."""

View File

@ -13,10 +13,10 @@ from . import MockAsyncTcpClient
async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
"""Test supported properties."""
state = hass.states.get("stt.wyoming")
state = hass.states.get("stt.test_asr")
assert state is not None
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None
assert entity.supported_languages == ["en-US"]
@ -29,7 +29,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
"""Test streaming audio."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None
async def audio_stream():
@ -51,7 +51,7 @@ async def test_streaming_audio_connection_lost(
hass: HomeAssistant, init_wyoming_stt
) -> None:
"""Test streaming audio and losing connection."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None
async def audio_stream():
@ -69,7 +69,7 @@ async def test_streaming_audio_connection_lost(
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
"""Test streaming audio and error raising."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None
async def audio_stream():