Allow entity names for STT entities (#91932)

* Allow entity names for STT entities

* Fix tests
pull/91940/head
Paulus Schoutsen 2023-04-23 23:06:34 -04:00 committed by GitHub
parent fba7c6cacd
commit a203149133
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 44 additions and 50 deletions

View File

@ -304,7 +304,10 @@ class PipelineRun:
if self.stt_provider is None: if self.stt_provider is None:
raise RuntimeError("Speech to text was not prepared") raise RuntimeError("Speech to text was not prepared")
engine = self.stt_provider.name if isinstance(self.stt_provider, stt.Provider):
engine = self.stt_provider.name
else:
engine = self.stt_provider.entity_id
self.process_event( self.process_event(
PipelineEvent( PipelineEvent(

View File

@ -44,6 +44,8 @@ async def async_setup_entry(
class DemoProviderEntity(SpeechToTextEntity): class DemoProviderEntity(SpeechToTextEntity):
"""Demo speech API provider entity.""" """Demo speech API provider entity."""
_attr_name = "Demo STT"
@property @property
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str]:
"""Return a list of supported languages.""" """Return a list of supported languages."""

View File

@ -128,16 +128,6 @@ class SpeechToTextEntity(RestoreEntity):
_attr_should_poll = False _attr_should_poll = False
__last_processed: str | None = None __last_processed: str | None = None
@property
@final
def name(self) -> str:
"""Return the name of the provider entity."""
# Only one entity is allowed per platform for now.
if self.platform is None:
raise RuntimeError("Entity is not added to hass yet.")
return self.platform.platform_name
@property @property
@final @final
def state(self) -> str | None: def state(self) -> str | None:
@ -249,11 +239,7 @@ class SpeechToTextView(HomeAssistantView):
hass: HomeAssistant = request.app["hass"] hass: HomeAssistant = request.app["hass"]
provider_entity: SpeechToTextEntity | None = None provider_entity: SpeechToTextEntity | None = None
if ( if (
not ( not (provider_entity := async_get_speech_to_text_entity(hass, provider))
provider_entity := async_get_speech_to_text_entity(
hass, f"{DOMAIN}.{provider}"
)
)
and provider not in self.providers and provider not in self.providers
): ):
raise HTTPNotFound() raise HTTPNotFound()
@ -292,11 +278,7 @@ class SpeechToTextView(HomeAssistantView):
"""Return provider specific audio information.""" """Return provider specific audio information."""
hass: HomeAssistant = request.app["hass"] hass: HomeAssistant = request.app["hass"]
if ( if (
not ( not (provider_entity := async_get_speech_to_text_entity(hass, provider))
provider_entity := async_get_speech_to_text_entity(
hass, f"{DOMAIN}.{provider}"
)
)
and provider not in self.providers and provider not in self.providers
): ):
raise HTTPNotFound() raise HTTPNotFound()

View File

@ -89,6 +89,8 @@ class MockSttProvider(BaseProvider, stt.Provider):
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity): class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
"""Mock provider entity.""" """Mock provider entity."""
_attr_name = "Mock STT"
class MockTTSProvider(tts.Provider): class MockTTSProvider(tts.Provider):
"""Mock TTS provider.""" """Mock TTS provider."""

View File

@ -97,7 +97,7 @@
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'engine': 'test', 'engine': 'stt.mock_stt',
'metadata': dict({ 'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>, 'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>, 'channel': <AudioChannels.CHANNEL_MONO: 1>,

View File

@ -1,10 +1,12 @@
"""The tests for the demo stt component.""" """The tests for the demo stt component."""
from http import HTTPStatus from http import HTTPStatus
from unittest.mock import patch
import pytest import pytest
from homeassistant.components import stt from homeassistant.components import stt
from homeassistant.components.demo import DOMAIN as DEMO_DOMAIN from homeassistant.components.demo import DOMAIN as DEMO_DOMAIN
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -24,7 +26,11 @@ async def setup_config_entry(hass: HomeAssistant) -> None:
"""Set up demo component from config entry.""" """Set up demo component from config entry."""
config_entry = MockConfigEntry(domain=DEMO_DOMAIN) config_entry = MockConfigEntry(domain=DEMO_DOMAIN)
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id) with patch(
"homeassistant.components.demo.COMPONENTS_WITH_CONFIG_ENTRY_DEMO_PLATFORM",
[Platform.STT],
):
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -103,7 +109,7 @@ async def test_config_entry_demo_speech(
client = await hass_client() client = await hass_client()
response = await client.post( response = await client.post(
"/api/stt/demo", "/api/stt/stt.demo_stt",
headers={ headers={
"X-Speech-Content": ( "X-Speech-Content": (
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=2;" "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=2;"

View File

@ -93,10 +93,15 @@ class BaseProvider:
class MockProvider(BaseProvider, Provider): class MockProvider(BaseProvider, Provider):
"""Mock provider.""" """Mock provider."""
url_path = TEST_DOMAIN
class MockProviderEntity(BaseProvider, SpeechToTextEntity): class MockProviderEntity(BaseProvider, SpeechToTextEntity):
"""Mock provider entity.""" """Mock provider entity."""
url_path = "stt.test"
_attr_name = "test"
@pytest.fixture @pytest.fixture
def mock_provider() -> MockProvider: def mock_provider() -> MockProvider:
@ -128,15 +133,19 @@ async def setup_fixture(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
request: pytest.FixtureRequest, request: pytest.FixtureRequest,
) -> None: ) -> MockProvider | MockProviderEntity:
"""Set up the test environment.""" """Set up the test environment."""
if request.param == "mock_setup": if request.param == "mock_setup":
await mock_setup(hass, tmp_path, MockProvider()) provider = MockProvider()
await mock_setup(hass, tmp_path, provider)
elif request.param == "mock_config_entry_setup": elif request.param == "mock_config_entry_setup":
await mock_config_entry_setup(hass, tmp_path, MockProviderEntity()) provider = MockProviderEntity()
await mock_config_entry_setup(hass, tmp_path, provider)
else: else:
raise RuntimeError("Invalid setup fixture") raise RuntimeError("Invalid setup fixture")
return provider
async def mock_setup( async def mock_setup(
hass: HomeAssistant, hass: HomeAssistant,
@ -206,11 +215,11 @@ async def mock_config_entry_setup(
async def test_get_provider_info( async def test_get_provider_info(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
setup: str, setup: MockProvider | MockProviderEntity,
) -> None: ) -> None:
"""Test engine that doesn't exist.""" """Test engine that doesn't exist."""
client = await hass_client() client = await hass_client()
response = await client.get(f"/api/stt/{TEST_DOMAIN}") response = await client.get(f"/api/stt/{setup.url_path}")
assert response.status == HTTPStatus.OK assert response.status == HTTPStatus.OK
assert await response.json() == { assert await response.json() == {
"languages": ["de", "de-CH", "en-US"], "languages": ["de", "de-CH", "en-US"],
@ -228,7 +237,7 @@ async def test_get_provider_info(
async def test_non_existing_provider( async def test_non_existing_provider(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
setup: str, setup: MockProvider | MockProviderEntity,
) -> None: ) -> None:
"""Test streaming to engine that doesn't exist.""" """Test streaming to engine that doesn't exist."""
client = await hass_client() client = await hass_client()
@ -255,14 +264,14 @@ async def test_non_existing_provider(
async def test_stream_audio( async def test_stream_audio(
hass: HomeAssistant, hass: HomeAssistant,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
setup: str, setup: MockProvider | MockProviderEntity,
) -> None: ) -> None:
"""Test streaming audio and getting response.""" """Test streaming audio and getting response."""
client = await hass_client() client = await hass_client()
# Language en is matched with en-US # Language en is matched with en-US
response = await client.post( response = await client.post(
f"/api/stt/{TEST_DOMAIN}", f"/api/stt/{setup.url_path}",
headers={ headers={
"X-Speech-Content": ( "X-Speech-Content": (
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;" "format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
@ -318,7 +327,7 @@ async def test_metadata_errors(
header: str | None, header: str | None,
status: int, status: int,
error: str, error: str,
setup: str, setup: MockProvider | MockProviderEntity,
) -> None: ) -> None:
"""Test metadata errors.""" """Test metadata errors."""
client = await hass_client() client = await hass_client()
@ -326,7 +335,7 @@ async def test_metadata_errors(
if header: if header:
headers["X-Speech-Content"] = header headers["X-Speech-Content"] = header
response = await client.post(f"/api/stt/{TEST_DOMAIN}", headers=headers) response = await client.post(f"/api/stt/{setup.url_path}", headers=headers)
assert response.status == status assert response.status == status
assert await response.text() == error assert await response.text() == error
@ -351,16 +360,6 @@ async def test_config_entry_unload(
assert config_entry.state == ConfigEntryState.NOT_LOADED assert config_entry.state == ConfigEntryState.NOT_LOADED
def test_entity_name_raises_before_addition(
hass: HomeAssistant,
tmp_path: Path,
mock_provider_entity: MockProviderEntity,
) -> None:
"""Test entity name raises before addition to Home Assistant."""
with pytest.raises(RuntimeError):
mock_provider_entity.name # pylint: disable=pointless-statement
async def test_restore_state( async def test_restore_state(
hass: HomeAssistant, hass: HomeAssistant,
tmp_path: Path, tmp_path: Path,
@ -388,7 +387,7 @@ async def test_restore_state(
async def test_ws_list_engines( async def test_ws_list_engines(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
setup: str, setup: MockProvider | MockProviderEntity,
engine_id: str, engine_id: str,
) -> None: ) -> None:
"""Test listing speech to text engines.""" """Test listing speech to text engines."""

View File

@ -13,10 +13,10 @@ from . import MockAsyncTcpClient
async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None: async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
"""Test supported properties.""" """Test supported properties."""
state = hass.states.get("stt.wyoming") state = hass.states.get("stt.test_asr")
assert state is not None assert state is not None
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None assert entity is not None
assert entity.supported_languages == ["en-US"] assert entity.supported_languages == ["en-US"]
@ -29,7 +29,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None: async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
"""Test streaming audio.""" """Test streaming audio."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None assert entity is not None
async def audio_stream(): async def audio_stream():
@ -51,7 +51,7 @@ async def test_streaming_audio_connection_lost(
hass: HomeAssistant, init_wyoming_stt hass: HomeAssistant, init_wyoming_stt
) -> None: ) -> None:
"""Test streaming audio and losing connection.""" """Test streaming audio and losing connection."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None assert entity is not None
async def audio_stream(): async def audio_stream():
@ -69,7 +69,7 @@ async def test_streaming_audio_connection_lost(
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None: async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
"""Test streaming audio and error raising.""" """Test streaming audio and error raising."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None assert entity is not None
async def audio_stream(): async def audio_stream():