Allow entity names for STT entities (#91932)
* Allow entity names for STT entities * Fix testspull/91940/head
parent
fba7c6cacd
commit
a203149133
|
@ -304,7 +304,10 @@ class PipelineRun:
|
|||
if self.stt_provider is None:
|
||||
raise RuntimeError("Speech to text was not prepared")
|
||||
|
||||
engine = self.stt_provider.name
|
||||
if isinstance(self.stt_provider, stt.Provider):
|
||||
engine = self.stt_provider.name
|
||||
else:
|
||||
engine = self.stt_provider.entity_id
|
||||
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
|
|
|
@ -44,6 +44,8 @@ async def async_setup_entry(
|
|||
class DemoProviderEntity(SpeechToTextEntity):
|
||||
"""Demo speech API provider entity."""
|
||||
|
||||
_attr_name = "Demo STT"
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
|
|
|
@ -128,16 +128,6 @@ class SpeechToTextEntity(RestoreEntity):
|
|||
_attr_should_poll = False
|
||||
__last_processed: str | None = None
|
||||
|
||||
@property
|
||||
@final
|
||||
def name(self) -> str:
|
||||
"""Return the name of the provider entity."""
|
||||
# Only one entity is allowed per platform for now.
|
||||
if self.platform is None:
|
||||
raise RuntimeError("Entity is not added to hass yet.")
|
||||
|
||||
return self.platform.platform_name
|
||||
|
||||
@property
|
||||
@final
|
||||
def state(self) -> str | None:
|
||||
|
@ -249,11 +239,7 @@ class SpeechToTextView(HomeAssistantView):
|
|||
hass: HomeAssistant = request.app["hass"]
|
||||
provider_entity: SpeechToTextEntity | None = None
|
||||
if (
|
||||
not (
|
||||
provider_entity := async_get_speech_to_text_entity(
|
||||
hass, f"{DOMAIN}.{provider}"
|
||||
)
|
||||
)
|
||||
not (provider_entity := async_get_speech_to_text_entity(hass, provider))
|
||||
and provider not in self.providers
|
||||
):
|
||||
raise HTTPNotFound()
|
||||
|
@ -292,11 +278,7 @@ class SpeechToTextView(HomeAssistantView):
|
|||
"""Return provider specific audio information."""
|
||||
hass: HomeAssistant = request.app["hass"]
|
||||
if (
|
||||
not (
|
||||
provider_entity := async_get_speech_to_text_entity(
|
||||
hass, f"{DOMAIN}.{provider}"
|
||||
)
|
||||
)
|
||||
not (provider_entity := async_get_speech_to_text_entity(hass, provider))
|
||||
and provider not in self.providers
|
||||
):
|
||||
raise HTTPNotFound()
|
||||
|
|
|
@ -89,6 +89,8 @@ class MockSttProvider(BaseProvider, stt.Provider):
|
|||
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
|
||||
"""Mock provider entity."""
|
||||
|
||||
_attr_name = "Mock STT"
|
||||
|
||||
|
||||
class MockTTSProvider(tts.Provider):
|
||||
"""Mock TTS provider."""
|
||||
|
|
|
@ -97,7 +97,7 @@
|
|||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'engine': 'stt.mock_stt',
|
||||
'metadata': dict({
|
||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
"""The tests for the demo stt component."""
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.components.demo import DOMAIN as DEMO_DOMAIN
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
@ -24,7 +26,11 @@ async def setup_config_entry(hass: HomeAssistant) -> None:
|
|||
"""Set up demo component from config entry."""
|
||||
config_entry = MockConfigEntry(domain=DEMO_DOMAIN)
|
||||
config_entry.add_to_hass(hass)
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
with patch(
|
||||
"homeassistant.components.demo.COMPONENTS_WITH_CONFIG_ENTRY_DEMO_PLATFORM",
|
||||
[Platform.STT],
|
||||
):
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
|
@ -103,7 +109,7 @@ async def test_config_entry_demo_speech(
|
|||
client = await hass_client()
|
||||
|
||||
response = await client.post(
|
||||
"/api/stt/demo",
|
||||
"/api/stt/stt.demo_stt",
|
||||
headers={
|
||||
"X-Speech-Content": (
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=2;"
|
||||
|
|
|
@ -93,10 +93,15 @@ class BaseProvider:
|
|||
class MockProvider(BaseProvider, Provider):
|
||||
"""Mock provider."""
|
||||
|
||||
url_path = TEST_DOMAIN
|
||||
|
||||
|
||||
class MockProviderEntity(BaseProvider, SpeechToTextEntity):
|
||||
"""Mock provider entity."""
|
||||
|
||||
url_path = "stt.test"
|
||||
_attr_name = "test"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider() -> MockProvider:
|
||||
|
@ -128,15 +133,19 @@ async def setup_fixture(
|
|||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
request: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
) -> MockProvider | MockProviderEntity:
|
||||
"""Set up the test environment."""
|
||||
if request.param == "mock_setup":
|
||||
await mock_setup(hass, tmp_path, MockProvider())
|
||||
provider = MockProvider()
|
||||
await mock_setup(hass, tmp_path, provider)
|
||||
elif request.param == "mock_config_entry_setup":
|
||||
await mock_config_entry_setup(hass, tmp_path, MockProviderEntity())
|
||||
provider = MockProviderEntity()
|
||||
await mock_config_entry_setup(hass, tmp_path, provider)
|
||||
else:
|
||||
raise RuntimeError("Invalid setup fixture")
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
async def mock_setup(
|
||||
hass: HomeAssistant,
|
||||
|
@ -206,11 +215,11 @@ async def mock_config_entry_setup(
|
|||
async def test_get_provider_info(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
setup: str,
|
||||
setup: MockProvider | MockProviderEntity,
|
||||
) -> None:
|
||||
"""Test engine that doesn't exist."""
|
||||
client = await hass_client()
|
||||
response = await client.get(f"/api/stt/{TEST_DOMAIN}")
|
||||
response = await client.get(f"/api/stt/{setup.url_path}")
|
||||
assert response.status == HTTPStatus.OK
|
||||
assert await response.json() == {
|
||||
"languages": ["de", "de-CH", "en-US"],
|
||||
|
@ -228,7 +237,7 @@ async def test_get_provider_info(
|
|||
async def test_non_existing_provider(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
setup: str,
|
||||
setup: MockProvider | MockProviderEntity,
|
||||
) -> None:
|
||||
"""Test streaming to engine that doesn't exist."""
|
||||
client = await hass_client()
|
||||
|
@ -255,14 +264,14 @@ async def test_non_existing_provider(
|
|||
async def test_stream_audio(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
setup: str,
|
||||
setup: MockProvider | MockProviderEntity,
|
||||
) -> None:
|
||||
"""Test streaming audio and getting response."""
|
||||
client = await hass_client()
|
||||
|
||||
# Language en is matched with en-US
|
||||
response = await client.post(
|
||||
f"/api/stt/{TEST_DOMAIN}",
|
||||
f"/api/stt/{setup.url_path}",
|
||||
headers={
|
||||
"X-Speech-Content": (
|
||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
||||
|
@ -318,7 +327,7 @@ async def test_metadata_errors(
|
|||
header: str | None,
|
||||
status: int,
|
||||
error: str,
|
||||
setup: str,
|
||||
setup: MockProvider | MockProviderEntity,
|
||||
) -> None:
|
||||
"""Test metadata errors."""
|
||||
client = await hass_client()
|
||||
|
@ -326,7 +335,7 @@ async def test_metadata_errors(
|
|||
if header:
|
||||
headers["X-Speech-Content"] = header
|
||||
|
||||
response = await client.post(f"/api/stt/{TEST_DOMAIN}", headers=headers)
|
||||
response = await client.post(f"/api/stt/{setup.url_path}", headers=headers)
|
||||
assert response.status == status
|
||||
assert await response.text() == error
|
||||
|
||||
|
@ -351,16 +360,6 @@ async def test_config_entry_unload(
|
|||
assert config_entry.state == ConfigEntryState.NOT_LOADED
|
||||
|
||||
|
||||
def test_entity_name_raises_before_addition(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
mock_provider_entity: MockProviderEntity,
|
||||
) -> None:
|
||||
"""Test entity name raises before addition to Home Assistant."""
|
||||
with pytest.raises(RuntimeError):
|
||||
mock_provider_entity.name # pylint: disable=pointless-statement
|
||||
|
||||
|
||||
async def test_restore_state(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
|
@ -388,7 +387,7 @@ async def test_restore_state(
|
|||
async def test_ws_list_engines(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
setup: str,
|
||||
setup: MockProvider | MockProviderEntity,
|
||||
engine_id: str,
|
||||
) -> None:
|
||||
"""Test listing speech to text engines."""
|
||||
|
|
|
@ -13,10 +13,10 @@ from . import MockAsyncTcpClient
|
|||
|
||||
async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||
"""Test supported properties."""
|
||||
state = hass.states.get("stt.wyoming")
|
||||
state = hass.states.get("stt.test_asr")
|
||||
assert state is not None
|
||||
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
|
||||
assert entity is not None
|
||||
|
||||
assert entity.supported_languages == ["en-US"]
|
||||
|
@ -29,7 +29,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
|||
|
||||
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
|
||||
"""Test streaming audio."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
|
||||
assert entity is not None
|
||||
|
||||
async def audio_stream():
|
||||
|
@ -51,7 +51,7 @@ async def test_streaming_audio_connection_lost(
|
|||
hass: HomeAssistant, init_wyoming_stt
|
||||
) -> None:
|
||||
"""Test streaming audio and losing connection."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
|
||||
assert entity is not None
|
||||
|
||||
async def audio_stream():
|
||||
|
@ -69,7 +69,7 @@ async def test_streaming_audio_connection_lost(
|
|||
|
||||
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||
"""Test streaming audio and error raising."""
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
|
||||
assert entity is not None
|
||||
|
||||
async def audio_stream():
|
||||
|
|
Loading…
Reference in New Issue