Cleanup language support on TTS (#5255)

* Cleanup language support on TTS

* change to default_language & address comments

* Cleanup not needed code / comment from paulus
pull/3734/merge
Pascal Vizeli 2017-01-11 16:31:16 +01:00 committed by Paulus Schoutsen
parent 467cb18625
commit 3f3a3bcc8a
6 changed files with 156 additions and 59 deletions

View File

@ -5,7 +5,6 @@ For more details about this component, please refer to the documentation at
https://home-assistant.io/components/tts/
"""
import asyncio
import functools
import hashlib
import logging
import mimetypes
@ -247,8 +246,6 @@ class SpeechManager(object):
def async_register_engine(self, engine, provider, config):
"""Register a TTS provider."""
provider.hass = self.hass
if CONF_LANG in config:
provider.language = config.get(CONF_LANG)
self.providers[engine] = provider
@asyncio.coroutine
@ -257,9 +254,16 @@ class SpeechManager(object):
This method is a coroutine.
"""
provider = self.providers[engine]
language = language or provider.default_language
if language is None or \
language not in provider.supported_languages:
raise HomeAssistantError("Not supported language {0}".format(
language))
msg_hash = hashlib.sha1(bytes(message, 'utf-8')).hexdigest()
language_key = language or self.providers[engine].language
key = KEY_PATTERN.format(msg_hash, language_key, engine).lower()
key = KEY_PATTERN.format(msg_hash, language, engine).lower()
use_cache = cache if cache is not None else self.use_cache
# is speech allready in memory
@ -387,13 +391,22 @@ class Provider(object):
"""Represent a single provider."""
hass = None
language = None
def get_tts_audio(self, message, language=None):
@property
def default_language(self):
"""Default language."""
return None
@property
def supported_languages(self):
"""List of supported languages."""
return None
def get_tts_audio(self, message, language):
"""Load tts audio file from provider."""
raise NotImplementedError()
def async_get_tts_audio(self, message, language=None):
def async_get_tts_audio(self, message, language):
"""Load tts audio file from provider.
Return a tuple of file extension and data as bytes.
@ -401,8 +414,7 @@ class Provider(object):
This method must be run in the event loop and returns a coroutine.
"""
return self.hass.loop.run_in_executor(
None,
functools.partial(self.get_tts_audio, message, language=language))
None, self.get_tts_audio, message, language)
class TextToSpeechView(HomeAssistantView):

View File

@ -6,28 +6,50 @@ https://home-assistant.io/components/demo/
"""
import os
from homeassistant.components.tts import Provider
import voluptuous as vol
from homeassistant.components.tts import Provider, PLATFORM_SCHEMA, CONF_LANG
SUPPORT_LANGUAGES = [
'en', 'de'
]
DEFAULT_LANG = 'en'
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES),
})
def get_engine(hass, config):
"""Setup Demo speech component."""
return DemoProvider()
return DemoProvider(config[CONF_LANG])
class DemoProvider(Provider):
"""Demo speech api provider."""
def __init__(self):
"""Initialize demo provider for TTS."""
self.language = 'en'
def __init__(self, lang):
"""Initialize demo provider."""
self._lang = lang
def get_tts_audio(self, message, language=None):
@property
def default_language(self):
"""Default language."""
return self._lang
@property
def supported_languages(self):
"""List of supported languages."""
return SUPPORT_LANGUAGES
def get_tts_audio(self, message, language):
"""Load TTS from demo."""
filename = os.path.join(os.path.dirname(__file__), "demo.mp3")
try:
with open(filename, 'rb') as voice:
data = voice.read()
except OSError:
return
return (None, None)
return ("mp3", data)

View File

@ -42,15 +42,16 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
@asyncio.coroutine
def async_get_engine(hass, config):
"""Setup Google speech component."""
return GoogleProvider(hass)
return GoogleProvider(hass, config[CONF_LANG])
class GoogleProvider(Provider):
"""Google speech api provider."""
def __init__(self, hass):
def __init__(self, hass, lang):
"""Init Google TTS service."""
self.hass = hass
self._lang = lang
self.headers = {
'Referer': "http://translate.google.com/",
'User-Agent': ("Mozilla/5.0 (Windows NT 10.0; WOW64) "
@ -58,8 +59,18 @@ class GoogleProvider(Provider):
"Chrome/47.0.2526.106 Safari/537.36")
}
@property
def default_language(self):
"""Default language."""
return self._lang
@property
def supported_languages(self):
"""List of supported languages."""
return SUPPORT_LANGUAGES
@asyncio.coroutine
def async_get_tts_audio(self, message, language=None):
def async_get_tts_audio(self, message, language):
"""Load TTS from google."""
from gtts_token import gtts_token
@ -67,11 +78,6 @@ class GoogleProvider(Provider):
websession = async_get_clientsession(self.hass)
message_parts = self._split_message_to_parts(message)
# If language is not specified or is not supported - use the language
# from the config.
if language not in SUPPORT_LANGUAGES:
language = self.language
data = b''
for idx, part in enumerate(message_parts):
part_token = yield from self.hass.loop.run_in_executor(

View File

@ -29,18 +29,31 @@ def get_engine(hass, config):
if shutil.which("pico2wave") is None:
_LOGGER.error("'pico2wave' was not found")
return False
return PicoProvider()
return PicoProvider(config[CONF_LANG])
class PicoProvider(Provider):
"""pico speech api provider."""
def get_tts_audio(self, message, language=None):
def __init__(self, lang):
"""Initialize pico provider."""
self._lang = lang
@property
def default_language(self):
"""Default language."""
return self._lang
@property
def supported_languages(self):
"""List of supported languages."""
return SUPPORT_LANGUAGES
def get_tts_audio(self, message, language):
"""Load TTS using pico2wave."""
if language not in SUPPORT_LANGUAGES:
language = self.language
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmpf:
fname = tmpf.name
cmd = ['pico2wave', '--wave', fname, '-l', language, message]
subprocess.call(cmd)
data = None
@ -52,6 +65,7 @@ class PicoProvider(Provider):
return (None, None)
finally:
os.remove(fname)
if data:
return ("wav", data)
return (None, None)

View File

@ -93,27 +93,34 @@ class VoiceRSSProvider(Provider):
def __init__(self, hass, conf):
"""Init VoiceRSS TTS service."""
self.hass = hass
self.extension = conf.get(CONF_CODEC)
self._extension = conf[CONF_CODEC]
self._lang = conf[CONF_LANG]
self.form_data = {
'key': conf.get(CONF_API_KEY),
'hl': conf.get(CONF_LANG),
'c': (conf.get(CONF_CODEC)).upper(),
'f': conf.get(CONF_FORMAT),
self._form_data = {
'key': conf[CONF_API_KEY],
'hl': conf[CONF_LANG],
'c': (conf[CONF_CODEC]).upper(),
'f': conf[CONF_FORMAT],
}
@property
def default_language(self):
"""Default language."""
return self._lang
@property
def supported_languages(self):
"""List of supported languages."""
return SUPPORT_LANGUAGES
@asyncio.coroutine
def async_get_tts_audio(self, message, language=None):
def async_get_tts_audio(self, message, language):
"""Load TTS from voicerss."""
websession = async_get_clientsession(self.hass)
form_data = self.form_data.copy()
form_data = self._form_data.copy()
form_data['src'] = message
# If language is specified and supported - use it instead of the
# language in the config.
if language in SUPPORT_LANGUAGES:
form_data['hl'] = language
form_data['hl'] = language
request = None
try:
@ -141,4 +148,4 @@ class VoiceRSSProvider(Provider):
if request is not None:
yield from request.release()
return (self.extension, data)
return (self._extension, data)

View File

@ -22,7 +22,7 @@ class TestTTS(object):
def setup_method(self):
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
self.demo_provider = DemoProvider()
self.demo_provider = DemoProvider('en')
self.default_tts_cache = self.hass.config.path(tts.DEFAULT_CACHE_DIR)
def teardown_method(self):
@ -95,7 +95,7 @@ class TestTTS(object):
config = {
tts.DOMAIN: {
'platform': 'demo',
'language': 'lang'
'language': 'de'
}
}
@ -111,11 +111,23 @@ class TestTTS(object):
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC
assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(
"/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd"
"_lang_demo.mp3") \
"_de_demo.mp3") \
!= -1
assert os.path.isfile(os.path.join(
self.default_tts_cache,
"265944c108cbb00b2a621be5930513e03a0bb2cd_lang_demo.mp3"))
"265944c108cbb00b2a621be5930513e03a0bb2cd_de_demo.mp3"))
def test_setup_component_and_test_service_with_wrong_conf_language(self):
"""Setup the demo platform and call service with wrong config."""
config = {
tts.DOMAIN: {
'platform': 'demo',
'language': 'ru'
}
}
with assert_setup_component(0, tts.DOMAIN):
setup_component(self.hass, tts.DOMAIN, config)
def test_setup_component_and_test_service_with_service_language(self):
"""Setup the demo platform and call service."""
@ -127,6 +139,35 @@ class TestTTS(object):
}
}
with assert_setup_component(1, tts.DOMAIN):
setup_component(self.hass, tts.DOMAIN, config)
self.hass.services.call(tts.DOMAIN, 'demo_say', {
tts.ATTR_MESSAGE: "I person is on front of your door.",
tts.ATTR_LANGUAGE: "de",
})
self.hass.block_till_done()
assert len(calls) == 1
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC
assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(
"/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd"
"_de_demo.mp3") \
!= -1
assert os.path.isfile(os.path.join(
self.default_tts_cache,
"265944c108cbb00b2a621be5930513e03a0bb2cd_de_demo.mp3"))
def test_setup_component_test_service_with_wrong_service_language(self):
"""Setup the demo platform and call service."""
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
config = {
tts.DOMAIN: {
'platform': 'demo',
}
}
with assert_setup_component(1, tts.DOMAIN):
setup_component(self.hass, tts.DOMAIN, config)
@ -136,13 +177,8 @@ class TestTTS(object):
})
self.hass.block_till_done()
assert len(calls) == 1
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MEDIA_TYPE_MUSIC
assert calls[0].data[ATTR_MEDIA_CONTENT_ID].find(
"/api/tts_proxy/265944c108cbb00b2a621be5930513e03a0bb2cd"
"_lang_demo.mp3") \
!= -1
assert os.path.isfile(os.path.join(
assert len(calls) == 0
assert not os.path.isfile(os.path.join(
self.default_tts_cache,
"265944c108cbb00b2a621be5930513e03a0bb2cd_lang_demo.mp3"))
@ -198,7 +234,7 @@ class TestTTS(object):
assert len(calls) == 1
req = requests.get(calls[0].data[ATTR_MEDIA_CONTENT_ID])
_, demo_data = self.demo_provider.get_tts_audio("bla")
_, demo_data = self.demo_provider.get_tts_audio("bla", 'en')
assert req.status_code == 200
assert req.content == demo_data
@ -319,7 +355,7 @@ class TestTTS(object):
"""Setup demo platform with cache and call service without cache."""
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
_, demo_data = self.demo_provider.get_tts_audio("bla")
_, demo_data = self.demo_provider.get_tts_audio("bla", 'en')
cache_file = os.path.join(
self.default_tts_cache,
"265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")
@ -339,7 +375,7 @@ class TestTTS(object):
setup_component(self.hass, tts.DOMAIN, config)
with patch('homeassistant.components.tts.demo.DemoProvider.'
'get_tts_audio', return_value=None):
'get_tts_audio', return_value=(None, None)):
self.hass.services.call(tts.DOMAIN, 'demo_say', {
tts.ATTR_MESSAGE: "I person is on front of your door.",
})
@ -352,7 +388,7 @@ class TestTTS(object):
!= -1
@patch('homeassistant.components.tts.demo.DemoProvider.get_tts_audio',
return_value=None)
return_value=(None, None))
def test_setup_component_test_with_error_on_get_tts(self, tts_mock):
"""Setup demo platform with wrong get_tts_audio."""
calls = mock_service(self.hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
@ -375,7 +411,7 @@ class TestTTS(object):
def test_setup_component_load_cache_retrieve_without_mem_cache(self):
"""Setup component and load cache and get without mem cache."""
_, demo_data = self.demo_provider.get_tts_audio("bla")
_, demo_data = self.demo_provider.get_tts_audio("bla", 'en')
cache_file = os.path.join(
self.default_tts_cache,
"265944c108cbb00b2a621be5930513e03a0bb2cd_en_demo.mp3")