Cleanup language support on TTS (#5255)
* Cleanup language support on TTS * change to default_language & address comments * Cleanup not needed code / comment from pauluspull/3734/merge
parent
467cb18625
commit
3f3a3bcc8a
|
@ -5,7 +5,6 @@ For more details about this component, please refer to the documentation at
|
|||
https://home-assistant.io/components/tts/
|
||||
"""
|
||||
import asyncio
|
||||
import functools
|
||||
import hashlib
|
||||
import logging
|
||||
import mimetypes
|
||||
|
@ -247,8 +246,6 @@ class SpeechManager(object):
|
|||
def async_register_engine(self, engine, provider, config):
|
||||
"""Register a TTS provider."""
|
||||
provider.hass = self.hass
|
||||
if CONF_LANG in config:
|
||||
provider.language = config.get(CONF_LANG)
|
||||
self.providers[engine] = provider
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -257,9 +254,16 @@ class SpeechManager(object):
|
|||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
provider = self.providers[engine]
|
||||
|
||||
language = language or provider.default_language
|
||||
if language is None or \
|
||||
language not in provider.supported_languages:
|
||||
raise HomeAssistantError("Not supported language {0}".format(
|
||||
language))
|
||||
|
||||
msg_hash = hashlib.sha1(bytes(message, 'utf-8')).hexdigest()
|
||||
language_key = language or self.providers[engine].language
|
||||
key = KEY_PATTERN.format(msg_hash, language_key, engine).lower()
|
||||
key = KEY_PATTERN.format(msg_hash, language, engine).lower()
|
||||
use_cache = cache if cache is not None else self.use_cache
|
||||
|
||||
# is speech allready in memory
|
||||
|
@ -387,13 +391,22 @@ class Provider(object):
|
|||
"""Represent a single provider."""
|
||||
|
||||
hass = None
|
||||
language = None
|
||||
|
||||
def get_tts_audio(self, message, language=None):
|
||||
@property
|
||||
def default_language(self):
|
||||
"""Default language."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def supported_languages(self):
|
||||
"""List of supported languages."""
|
||||
return None
|
||||
|
||||
def get_tts_audio(self, message, language):
|
||||
"""Load tts audio file from provider."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def async_get_tts_audio(self, message, language=None):
|
||||
def async_get_tts_audio(self, message, language):
|
||||
"""Load tts audio file from provider.
|
||||
|
||||
Return a tuple of file extension and data as bytes.
|
||||
|
@ -401,8 +414,7 @@ class Provider(object):
|
|||
This method must be run in the event loop and returns a coroutine.
|
||||
"""
|
||||
return self.hass.loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(self.get_tts_audio, message, language=language))
|
||||
None, self.get_tts_audio, message, language)
|
||||
|
||||
|
||||
class TextToSpeechView(HomeAssistantView):
|
||||
|
|
|
@ -6,28 +6,50 @@ https://home-assistant.io/components/demo/
|
|||
"""
|
||||
import os
|
||||
|
||||
from homeassistant.components.tts import Provider
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.tts import Provider, PLATFORM_SCHEMA, CONF_LANG
|
||||
|
||||
SUPPORT_LANGUAGES = [
|
||||
'en', 'de'
|
||||
]
|
||||
|
||||
DEFAULT_LANG = 'en'
|
||||
|
||||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
|
||||
vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES),
|
||||
})
|
||||
|
||||
|
||||
def get_engine(hass, config):
|
||||
"""Setup Demo speech component."""
|
||||
return DemoProvider()
|
||||
return DemoProvider(config[CONF_LANG])
|
||||
|
||||
|
||||
class DemoProvider(Provider):
|
||||
"""Demo speech api provider."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize demo provider for TTS."""
|
||||
self.language = 'en'
|
||||
def __init__(self, lang):
|
||||
"""Initialize demo provider."""
|
||||
self._lang = lang
|
||||
|
||||
def get_tts_audio(self, message, language=None):
|
||||
@property
|
||||
def default_language(self):
|
||||
"""Default language."""
|
||||
return self._lang
|
||||
|
||||
@property
|
||||
def supported_languages(self):
|
||||
"""List of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
|
||||
def get_tts_audio(self, message, language):
|
||||
"""Load TTS from demo."""
|
||||
filename = os.path.join(os.path.dirname(__file__), "demo.mp3")
|
||||
try:
|
||||
with open(filename, 'rb') as voice:
|
||||
data = voice.read()
|
||||
except OSError:
|
||||
return
|
||||
return (None, None)
|
||||
|
||||
return ("mp3", data)
|
||||
|
|
|
@ -42,15 +42,16 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
|
|||
@asyncio.coroutine
|
||||
def async_get_engine(hass, config):
|
||||
"""Setup Google speech component."""
|
||||
return GoogleProvider(hass)
|
||||
return GoogleProvider(hass, config[CONF_LANG])
|
||||
|
||||
|
||||
class GoogleProvider(Provider):
|
||||
"""Google speech api provider."""
|
||||
|
||||
def __init__(self, hass):
|
||||
def __init__(self, hass, lang):
|
||||
"""Init Google TTS service."""
|
||||
self.hass = hass
|
||||
self._lang = lang
|
||||
self.headers = {
|
||||
'Referer': "http://translate.google.com/",
|
||||
'User-Agent': ("Mozilla/5.0 (Windows NT 10.0; WOW64) "
|
||||
|
@ -58,8 +59,18 @@ class GoogleProvider(Provider):
|
|||
"Chrome/47.0.2526.106 Safari/537.36")
|
||||
}
|
||||
|
||||
@property
|
||||
def default_language(self):
|
||||
"""Default language."""
|
||||
return self._lang
|
||||
|
||||
@property
|
||||
def supported_languages(self):
|
||||
"""List of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_get_tts_audio(self, message, language=None):
|
||||
def async_get_tts_audio(self, message, language):
|
||||
"""Load TTS from google."""
|
||||
from gtts_token import gtts_token
|
||||
|
||||
|
@ -67,11 +78,6 @@ class GoogleProvider(Provider):
|
|||
websession = async_get_clientsession(self.hass)
|
||||
message_parts = self._split_message_to_parts(message)
|
||||
|
||||
# If language is not specified or is not supported - use the language
|
||||
# from the config.
|
||||
if language not in SUPPORT_LANGUAGES:
|
||||
language = self.language
|
||||
|
||||
data = b''
|
||||
for idx, part in enumerate(message_parts):
|
||||
part_token = yield from self.hass.loop.run_in_executor(
|
||||
|
|
|
@ -29,18 +29,31 @@ def get_engine(hass, config):
|
|||
if shutil.which("pico2wave") is None:
|
||||
_LOGGER.error("'pico2wave' was not found")
|
||||
return False
|
||||
return PicoProvider()
|
||||
return PicoProvider(config[CONF_LANG])
|
||||
|
||||
|
||||
class PicoProvider(Provider):
|
||||
"""pico speech api provider."""
|
||||
|
||||
def get_tts_audio(self, message, language=None):
|
||||
def __init__(self, lang):
|
||||
"""Initialize pico provider."""
|
||||
self._lang = lang
|
||||
|
||||
@property
|
||||
def default_language(self):
|
||||
"""Default language."""
|
||||
return self._lang
|
||||
|
||||
@property
|
||||
def supported_languages(self):
|
||||
"""List of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
|
||||
def get_tts_audio(self, message, language):
|
||||
"""Load TTS using pico2wave."""
|
||||
if language not in SUPPORT_LANGUAGES:
|
||||
language = self.language
|
||||
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmpf:
|
||||
fname = tmpf.name
|
||||
|
||||
cmd = ['pico2wave', '--wave', fname, '-l', language, message]
|
||||
subprocess.call(cmd)
|
||||
data = None
|
||||
|
@ -52,6 +65,7 @@ class PicoProvider(Provider):
|
|||
return (None, None)
|
||||
finally:
|
||||
os.remove(fname)
|
||||
|
||||
if data:
|
||||
return ("wav", data)
|
||||
return (None, None)
|
||||
|
|
|
@ -93,27 +93,34 @@ class VoiceRSSProvider(Provider):
|
|||
def __init__(self, hass, conf):
|
||||
"""Init VoiceRSS TTS service."""
|
||||
self.hass = hass
|
||||
self.extension = conf.get(CONF_CODEC)
|
||||
self._extension = conf[CONF_CODEC]
|
||||
self._lang = conf[CONF_LANG]
|
||||
|
||||
self.form_data = {
|
||||
'key': conf.get(CONF_API_KEY),
|
||||
'hl': conf.get(CONF_LANG),
|
||||
'c': (conf.get(CONF_CODEC)).upper(),
|
||||
'f': conf.get(CONF_FORMAT),
|
||||
self._form_data = {
|
||||
'key': conf[CONF_API_KEY],
|
||||
'hl': conf[CONF_LANG],
|
||||
'c': (conf[CONF_CODEC]).upper(),
|
||||
'f': conf[CONF_FORMAT],
|
||||
}
|
||||
|
||||
@property
|
||||
def default_language(self):
|
||||
"""Default language."""
|
||||
return self._lang
|
||||
|
||||
@property
|
||||
def supported_languages(self):
|
||||
"""List of supported languages."""
|
||||
return SUPPORT_LANGUAGES
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_get_tts_audio(self, message, language=None):
|
||||
def async_get_tts_audio(self, message, language):
|
||||
"""Load TTS from voicerss."""
|
||||
websession = async_get_clientsession(self.hass)
|
||||
form_data = self.form_data.copy()
|
||||
form_data = self._form_data.copy()
|
||||
|
||||
form_data['src'] = message
|
||||
|
||||
# If language is specified and supported - use it instead of the
|
||||
# language in the config.
|
||||
if language in SUPPORT_LANGUAGES:
|
||||
form_data['hl'] = language
|
||||
form_data['hl'] = language
|
||||
|
||||
request = None
|
||||
try:
|
||||
|
@ -141,4 +148,4 @@ class VoiceRSSProvider(Provider):
|
|||
if request is not None:
|
||||
yield from request.release()
|
||||
|
||||
return (self.extension, data)
|
||||
return (self._extension, data)
|
||||
|
|
|
@ -22,7 +22,7 @@ class TestTTS(object):
|
|||
def setup_method(self):
|
||||
"""Setup things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
self.demo_provider = DemoProvider()
|
||||
self.demo_provider = DemoProvider('en')
|
||||
self.default_tts_cache = self.hass.config.path(tts.DEFAULT_CACHE_DIR)
|
||||
|
||||
def teardown_method(self):
|
||||
|
@ -95,7 +95,7 @@ class TestTTS(object):
|
|||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
'language': 'lang'
|
||||
'language': 'de'
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -111,11 +111,23 @@ class TestTTS(object):
|
|||
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC
|
||||
assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(
|
||||
"/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd"
|
||||
"_lang_demo.mp3") \
|
||||
"_de_demo.mp3") \
|
||||
!= -1
|
||||
assert os.path.isfile(os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_lang_demo.mp3"))
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_de_demo.mp3"))
|
||||
|
||||
def test_setup_component_and_test_service_with_wrong_conf_language(self):
|
||||
"""Setup the demo platform and call service with wrong config."""
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
'language': 'ru'
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(0, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
def test_setup_component_and_test_service_with_service_language(self):
|
||||
"""Setup the demo platform and call service."""
|
||||
|
@ -127,6 +139,35 @@ class TestTTS(object):
|
|||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
self.hass.services.call(tts.DOMAIN, 'demo_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
tts.ATTR_LANGUAGE: "de",
|
||||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC
|
||||
assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(
|
||||
"/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd"
|
||||
"_de_demo.mp3") \
|
||||
!= -1
|
||||
assert os.path.isfile(os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_de_demo.mp3"))
|
||||
|
||||
def test_setup_component_test_service_with_wrong_service_language(self):
|
||||
"""Setup the demo platform and call service."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
config = {
|
||||
tts.DOMAIN: {
|
||||
'platform': 'demo',
|
||||
}
|
||||
}
|
||||
|
||||
with assert_setup_component(1, tts.DOMAIN):
|
||||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
|
@ -136,13 +177,8 @@ class TestTTS(object):
|
|||
})
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC
|
||||
assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(
|
||||
"/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd"
|
||||
"_lang_demo.mp3") \
|
||||
!= -1
|
||||
assert os.path.isfile(os.path.join(
|
||||
assert len(calls) == 0
|
||||
assert not os.path.isfile(os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_lang_demo.mp3"))
|
||||
|
||||
|
@ -198,7 +234,7 @@ class TestTTS(object):
|
|||
|
||||
assert len(calls) == 1
|
||||
req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID])
|
||||
_, demo_data = self.demo_provider.get_tts_audio("bla")
|
||||
_, demo_data = self.demo_provider.get_tts_audio("bla", 'en')
|
||||
assert req.status_code == 200
|
||||
assert req.content == demo_data
|
||||
|
||||
|
@ -319,7 +355,7 @@ class TestTTS(object):
|
|||
"""Setup demo platform with cache and call service without cache."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
||||
_, demo_data = self.demo_provider.get_tts_audio("bla")
|
||||
_, demo_data = self.demo_provider.get_tts_audio("bla", 'en')
|
||||
cache_file = os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")
|
||||
|
@ -339,7 +375,7 @@ class TestTTS(object):
|
|||
setup_component(self.hass, tts.DOMAIN, config)
|
||||
|
||||
with patch('homeassistant.components.tts.demo.DemoProvider.'
|
||||
'get_tts_audio', return_value=None):
|
||||
'get_tts_audio', return_value=(None, None)):
|
||||
self.hass.services.call(tts.DOMAIN, 'demo_say', {
|
||||
tts.ATTR_MESSAGE: "I person is on front of your door.",
|
||||
})
|
||||
|
@ -352,7 +388,7 @@ class TestTTS(object):
|
|||
!= -1
|
||||
|
||||
@patch('homeassistant.components.tts.demo.DemoProvider.get_tts_audio',
|
||||
return_value=None)
|
||||
return_value=(None, None))
|
||||
def test_setup_component_test_with_error_on_get_tts(self, tts_mock):
|
||||
"""Setup demo platform with wrong get_tts_audio."""
|
||||
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
|
||||
|
@ -375,7 +411,7 @@ class TestTTS(object):
|
|||
|
||||
def test_setup_component_load_cache_retrieve_without_mem_cache(self):
|
||||
"""Setup component and load cache and get without mem cache."""
|
||||
_, demo_data = self.demo_provider.get_tts_audio("bla")
|
||||
_, demo_data = self.demo_provider.get_tts_audio("bla", 'en')
|
||||
cache_file = os.path.join(
|
||||
self.default_tts_cache,
|
||||
"265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")
|
||||
|
|
Loading…
Reference in New Issue