Allow entity names for STT entities (#91932)
* Allow entity names for STT entities * Fix testspull/91940/head
parent
fba7c6cacd
commit
a203149133
|
@ -304,7 +304,10 @@ class PipelineRun:
|
||||||
if self.stt_provider is None:
|
if self.stt_provider is None:
|
||||||
raise RuntimeError("Speech to text was not prepared")
|
raise RuntimeError("Speech to text was not prepared")
|
||||||
|
|
||||||
engine = self.stt_provider.name
|
if isinstance(self.stt_provider, stt.Provider):
|
||||||
|
engine = self.stt_provider.name
|
||||||
|
else:
|
||||||
|
engine = self.stt_provider.entity_id
|
||||||
|
|
||||||
self.process_event(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
|
|
|
@ -44,6 +44,8 @@ async def async_setup_entry(
|
||||||
class DemoProviderEntity(SpeechToTextEntity):
|
class DemoProviderEntity(SpeechToTextEntity):
|
||||||
"""Demo speech API provider entity."""
|
"""Demo speech API provider entity."""
|
||||||
|
|
||||||
|
_attr_name = "Demo STT"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str]:
|
def supported_languages(self) -> list[str]:
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
|
|
|
@ -128,16 +128,6 @@ class SpeechToTextEntity(RestoreEntity):
|
||||||
_attr_should_poll = False
|
_attr_should_poll = False
|
||||||
__last_processed: str | None = None
|
__last_processed: str | None = None
|
||||||
|
|
||||||
@property
|
|
||||||
@final
|
|
||||||
def name(self) -> str:
|
|
||||||
"""Return the name of the provider entity."""
|
|
||||||
# Only one entity is allowed per platform for now.
|
|
||||||
if self.platform is None:
|
|
||||||
raise RuntimeError("Entity is not added to hass yet.")
|
|
||||||
|
|
||||||
return self.platform.platform_name
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@final
|
@final
|
||||||
def state(self) -> str | None:
|
def state(self) -> str | None:
|
||||||
|
@ -249,11 +239,7 @@ class SpeechToTextView(HomeAssistantView):
|
||||||
hass: HomeAssistant = request.app["hass"]
|
hass: HomeAssistant = request.app["hass"]
|
||||||
provider_entity: SpeechToTextEntity | None = None
|
provider_entity: SpeechToTextEntity | None = None
|
||||||
if (
|
if (
|
||||||
not (
|
not (provider_entity := async_get_speech_to_text_entity(hass, provider))
|
||||||
provider_entity := async_get_speech_to_text_entity(
|
|
||||||
hass, f"{DOMAIN}.{provider}"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
and provider not in self.providers
|
and provider not in self.providers
|
||||||
):
|
):
|
||||||
raise HTTPNotFound()
|
raise HTTPNotFound()
|
||||||
|
@ -292,11 +278,7 @@ class SpeechToTextView(HomeAssistantView):
|
||||||
"""Return provider specific audio information."""
|
"""Return provider specific audio information."""
|
||||||
hass: HomeAssistant = request.app["hass"]
|
hass: HomeAssistant = request.app["hass"]
|
||||||
if (
|
if (
|
||||||
not (
|
not (provider_entity := async_get_speech_to_text_entity(hass, provider))
|
||||||
provider_entity := async_get_speech_to_text_entity(
|
|
||||||
hass, f"{DOMAIN}.{provider}"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
and provider not in self.providers
|
and provider not in self.providers
|
||||||
):
|
):
|
||||||
raise HTTPNotFound()
|
raise HTTPNotFound()
|
||||||
|
|
|
@ -89,6 +89,8 @@ class MockSttProvider(BaseProvider, stt.Provider):
|
||||||
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
|
class MockSttProviderEntity(BaseProvider, stt.SpeechToTextEntity):
|
||||||
"""Mock provider entity."""
|
"""Mock provider entity."""
|
||||||
|
|
||||||
|
_attr_name = "Mock STT"
|
||||||
|
|
||||||
|
|
||||||
class MockTTSProvider(tts.Provider):
|
class MockTTSProvider(tts.Provider):
|
||||||
"""Mock TTS provider."""
|
"""Mock TTS provider."""
|
||||||
|
|
|
@ -97,7 +97,7 @@
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'engine': 'test',
|
'engine': 'stt.mock_stt',
|
||||||
'metadata': dict({
|
'metadata': dict({
|
||||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
"""The tests for the demo stt component."""
|
"""The tests for the demo stt component."""
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import stt
|
from homeassistant.components import stt
|
||||||
from homeassistant.components.demo import DOMAIN as DEMO_DOMAIN
|
from homeassistant.components.demo import DOMAIN as DEMO_DOMAIN
|
||||||
|
from homeassistant.const import Platform
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
@ -24,7 +26,11 @@ async def setup_config_entry(hass: HomeAssistant) -> None:
|
||||||
"""Set up demo component from config entry."""
|
"""Set up demo component from config entry."""
|
||||||
config_entry = MockConfigEntry(domain=DEMO_DOMAIN)
|
config_entry = MockConfigEntry(domain=DEMO_DOMAIN)
|
||||||
config_entry.add_to_hass(hass)
|
config_entry.add_to_hass(hass)
|
||||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
with patch(
|
||||||
|
"homeassistant.components.demo.COMPONENTS_WITH_CONFIG_ENTRY_DEMO_PLATFORM",
|
||||||
|
[Platform.STT],
|
||||||
|
):
|
||||||
|
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
@ -103,7 +109,7 @@ async def test_config_entry_demo_speech(
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/stt/demo",
|
"/api/stt/stt.demo_stt",
|
||||||
headers={
|
headers={
|
||||||
"X-Speech-Content": (
|
"X-Speech-Content": (
|
||||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=2;"
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=2;"
|
||||||
|
|
|
@ -93,10 +93,15 @@ class BaseProvider:
|
||||||
class MockProvider(BaseProvider, Provider):
|
class MockProvider(BaseProvider, Provider):
|
||||||
"""Mock provider."""
|
"""Mock provider."""
|
||||||
|
|
||||||
|
url_path = TEST_DOMAIN
|
||||||
|
|
||||||
|
|
||||||
class MockProviderEntity(BaseProvider, SpeechToTextEntity):
|
class MockProviderEntity(BaseProvider, SpeechToTextEntity):
|
||||||
"""Mock provider entity."""
|
"""Mock provider entity."""
|
||||||
|
|
||||||
|
url_path = "stt.test"
|
||||||
|
_attr_name = "test"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_provider() -> MockProvider:
|
def mock_provider() -> MockProvider:
|
||||||
|
@ -128,15 +133,19 @@ async def setup_fixture(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
request: pytest.FixtureRequest,
|
request: pytest.FixtureRequest,
|
||||||
) -> None:
|
) -> MockProvider | MockProviderEntity:
|
||||||
"""Set up the test environment."""
|
"""Set up the test environment."""
|
||||||
if request.param == "mock_setup":
|
if request.param == "mock_setup":
|
||||||
await mock_setup(hass, tmp_path, MockProvider())
|
provider = MockProvider()
|
||||||
|
await mock_setup(hass, tmp_path, provider)
|
||||||
elif request.param == "mock_config_entry_setup":
|
elif request.param == "mock_config_entry_setup":
|
||||||
await mock_config_entry_setup(hass, tmp_path, MockProviderEntity())
|
provider = MockProviderEntity()
|
||||||
|
await mock_config_entry_setup(hass, tmp_path, provider)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Invalid setup fixture")
|
raise RuntimeError("Invalid setup fixture")
|
||||||
|
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
async def mock_setup(
|
async def mock_setup(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@ -206,11 +215,11 @@ async def mock_config_entry_setup(
|
||||||
async def test_get_provider_info(
|
async def test_get_provider_info(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
setup: str,
|
setup: MockProvider | MockProviderEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test engine that doesn't exist."""
|
"""Test engine that doesn't exist."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
response = await client.get(f"/api/stt/{TEST_DOMAIN}")
|
response = await client.get(f"/api/stt/{setup.url_path}")
|
||||||
assert response.status == HTTPStatus.OK
|
assert response.status == HTTPStatus.OK
|
||||||
assert await response.json() == {
|
assert await response.json() == {
|
||||||
"languages": ["de", "de-CH", "en-US"],
|
"languages": ["de", "de-CH", "en-US"],
|
||||||
|
@ -228,7 +237,7 @@ async def test_get_provider_info(
|
||||||
async def test_non_existing_provider(
|
async def test_non_existing_provider(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
setup: str,
|
setup: MockProvider | MockProviderEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test streaming to engine that doesn't exist."""
|
"""Test streaming to engine that doesn't exist."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
@ -255,14 +264,14 @@ async def test_non_existing_provider(
|
||||||
async def test_stream_audio(
|
async def test_stream_audio(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_client: ClientSessionGenerator,
|
hass_client: ClientSessionGenerator,
|
||||||
setup: str,
|
setup: MockProvider | MockProviderEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test streaming audio and getting response."""
|
"""Test streaming audio and getting response."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
|
||||||
# Language en is matched with en-US
|
# Language en is matched with en-US
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"/api/stt/{TEST_DOMAIN}",
|
f"/api/stt/{setup.url_path}",
|
||||||
headers={
|
headers={
|
||||||
"X-Speech-Content": (
|
"X-Speech-Content": (
|
||||||
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
"format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1;"
|
||||||
|
@ -318,7 +327,7 @@ async def test_metadata_errors(
|
||||||
header: str | None,
|
header: str | None,
|
||||||
status: int,
|
status: int,
|
||||||
error: str,
|
error: str,
|
||||||
setup: str,
|
setup: MockProvider | MockProviderEntity,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test metadata errors."""
|
"""Test metadata errors."""
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
|
@ -326,7 +335,7 @@ async def test_metadata_errors(
|
||||||
if header:
|
if header:
|
||||||
headers["X-Speech-Content"] = header
|
headers["X-Speech-Content"] = header
|
||||||
|
|
||||||
response = await client.post(f"/api/stt/{TEST_DOMAIN}", headers=headers)
|
response = await client.post(f"/api/stt/{setup.url_path}", headers=headers)
|
||||||
assert response.status == status
|
assert response.status == status
|
||||||
assert await response.text() == error
|
assert await response.text() == error
|
||||||
|
|
||||||
|
@ -351,16 +360,6 @@ async def test_config_entry_unload(
|
||||||
assert config_entry.state == ConfigEntryState.NOT_LOADED
|
assert config_entry.state == ConfigEntryState.NOT_LOADED
|
||||||
|
|
||||||
|
|
||||||
def test_entity_name_raises_before_addition(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
tmp_path: Path,
|
|
||||||
mock_provider_entity: MockProviderEntity,
|
|
||||||
) -> None:
|
|
||||||
"""Test entity name raises before addition to Home Assistant."""
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
mock_provider_entity.name # pylint: disable=pointless-statement
|
|
||||||
|
|
||||||
|
|
||||||
async def test_restore_state(
|
async def test_restore_state(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
|
@ -388,7 +387,7 @@ async def test_restore_state(
|
||||||
async def test_ws_list_engines(
|
async def test_ws_list_engines(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
setup: str,
|
setup: MockProvider | MockProviderEntity,
|
||||||
engine_id: str,
|
engine_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test listing speech to text engines."""
|
"""Test listing speech to text engines."""
|
||||||
|
|
|
@ -13,10 +13,10 @@ from . import MockAsyncTcpClient
|
||||||
|
|
||||||
async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||||
"""Test supported properties."""
|
"""Test supported properties."""
|
||||||
state = hass.states.get("stt.wyoming")
|
state = hass.states.get("stt.test_asr")
|
||||||
assert state is not None
|
assert state is not None
|
||||||
|
|
||||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
|
||||||
assert entity is not None
|
assert entity is not None
|
||||||
|
|
||||||
assert entity.supported_languages == ["en-US"]
|
assert entity.supported_languages == ["en-US"]
|
||||||
|
@ -29,7 +29,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||||
|
|
||||||
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
|
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
|
||||||
"""Test streaming audio."""
|
"""Test streaming audio."""
|
||||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
|
||||||
assert entity is not None
|
assert entity is not None
|
||||||
|
|
||||||
async def audio_stream():
|
async def audio_stream():
|
||||||
|
@ -51,7 +51,7 @@ async def test_streaming_audio_connection_lost(
|
||||||
hass: HomeAssistant, init_wyoming_stt
|
hass: HomeAssistant, init_wyoming_stt
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test streaming audio and losing connection."""
|
"""Test streaming audio and losing connection."""
|
||||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
|
||||||
assert entity is not None
|
assert entity is not None
|
||||||
|
|
||||||
async def audio_stream():
|
async def audio_stream():
|
||||||
|
@ -69,7 +69,7 @@ async def test_streaming_audio_connection_lost(
|
||||||
|
|
||||||
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
|
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||||
"""Test streaming audio and error raising."""
|
"""Test streaming audio and error raising."""
|
||||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
|
||||||
assert entity is not None
|
assert entity is not None
|
||||||
|
|
||||||
async def audio_stream():
|
async def audio_stream():
|
||||||
|
|
Loading…
Reference in New Issue