Issues 356 - Refactoring TTS

pull/420/head
Jonathan D'Orleans 2016-09-16 14:12:43 -04:00
parent 885fe0a1cf
commit e76c1b7d21
8 changed files with 114 additions and 132 deletions

View File

@ -16,18 +16,10 @@
# along with Mycroft Core. If not, see <http://www.gnu.org/licenses/>.
import logging
import abc
from abc import ABCMeta, abstractmethod
from os.path import dirname, exists, isdir
from mycroft.configuration import ConfigurationManager
from mycroft.tts import espeak_tts
from mycroft.tts import fa_tts
from mycroft.tts import google_tts
from mycroft.tts import mary_tts
from mycroft.tts import mimic_tts
from mycroft.tts import spdsay_tts
from mycroft.util.log import getLogger
__author__ = 'jdorleans'
@ -42,15 +34,17 @@ class TTS(object):
It aggregates the minimum required parameters and exposes
``execute(sentence)`` function.
"""
__metaclass__ = ABCMeta
def __init__(self, lang, voice, filename='/tmp/tts.wav'):
def __init__(self, lang, voice, validator):
super(TTS, self).__init__()
self.lang = lang
self.voice = voice
self.filename = filename
self.filename = '/tmp/tts.wav'
self.validator = validator
@abc.abstractmethod
def execute(self, sentence, client):
@abstractmethod
def execute(self, sentence):
pass
@ -61,48 +55,61 @@ class TTSValidator(object):
It exposes and implements ``validate(tts)`` function as a template to
validate the TTS engines.
"""
__metaclass__ = ABCMeta
def __init__(self):
pass
def __init__(self, tts):
self.tts = tts
def validate(self, tts):
self.__validate_instance(tts)
self.__validate_filename(tts.filename)
self.validate_lang(tts.lang)
self.validate_connection(tts)
def validate(self):
self.validate_instance()
self.validate_filename()
self.validate_lang()
self.validate_connection()
def __validate_instance(self, tts):
instance = self.get_instance()
if not isinstance(tts, instance):
raise AttributeError(
'tts must be instance of ' + instance.__name__)
LOGGER.debug('TTS: ' + str(instance))
def validate_instance(self):
clazz = self.get_tts_class()
if not isinstance(self.tts, clazz):
raise AttributeError('tts must be instance of ' + clazz.__name__)
def __validate_filename(self, filename):
def validate_filename(self):
filename = self.tts.filename
if not (filename and filename.endswith('.wav')):
raise AttributeError(
'filename: ' + filename + ' must be a .wav file!')
raise AttributeError('file: %s must be in .wav format!' % filename)
dir_path = dirname(filename)
if not (exists(dir_path) and isdir(dir_path)):
raise AttributeError(
'filename: ' + filename + ' is not a valid file path!')
LOGGER.debug('Filename: ' + filename)
raise AttributeError('filename: %s is not valid!' % filename)
@abc.abstractmethod
def validate_lang(self, lang):
@abstractmethod
def validate_lang(self):
pass
@abc.abstractmethod
def validate_connection(self, tts):
@abstractmethod
def validate_connection(self):
pass
@abc.abstractmethod
def get_instance(self):
@abstractmethod
def get_tts_class(self):
pass
class TTSFactory(object):
from mycroft.tts.espeak_tts import ESpeak
from mycroft.tts.fa_tts import FATTS
from mycroft.tts.google_tts import GoogleTTS
from mycroft.tts.mary_tts import MaryTTS
from mycroft.tts.mimic_tts import Mimic
from mycroft.tts.spdsay_tts import SpdSay
CLASSES = {
"mimic": Mimic,
"google": GoogleTTS,
"marytts": MaryTTS,
"fatts": FATTS,
"espeak": ESpeak,
"spdsay": SpdSay
}
@staticmethod
def create():
"""
@ -116,28 +123,18 @@ class TTSFactory(object):
}
"""
logging.basicConfig()
config = ConfigurationManager.get().get('tts')
name = config.get('module')
lang = config.get(name).get('lang')
voice = config.get(name).get('voice')
from mycroft.tts.remote_tts import RemoteTTS
config = ConfigurationManager.get().get('tts', {})
module = config.get('module', 'mimic')
lang = config.get(module).get('lang')
voice = config.get(module).get('voice')
clazz = TTSFactory.CLASSES.get(module)
if name == mimic_tts.NAME:
tts = mimic_tts.Mimic(lang, voice)
mimic_tts.MimicValidator().validate(tts)
elif name == google_tts.NAME:
tts = google_tts.GoogleTTS(lang, voice)
google_tts.GoogleTTSValidator().validate(tts)
elif name == mary_tts.NAME:
tts = mary_tts.MaryTTS(lang, voice, config[name + '.url'])
mary_tts.MaryTTSValidator().validate(tts)
elif name == fa_tts.NAME:
tts = fa_tts.FATTS(lang, voice, config[name + '.url'])
fa_tts.FATTSValidator().validate(tts)
elif name == espeak_tts.NAME:
tts = espeak_tts.ESpeak(lang, voice)
espeak_tts.ESpeakValidator().validate(tts)
if issubclass(clazz, RemoteTTS):
url = config.get(module).get('url')
tts = clazz(lang, voice, url)
else:
tts = spdsay_tts.SpdSay(lang, voice)
spdsay_tts.SpdSayValidator().validate(tts)
tts = clazz(lang, voice)
tts.validator.validate()
return tts

View File

@ -20,14 +20,12 @@ import subprocess
from mycroft.tts import TTS, TTSValidator
__author__ = 'seanfitz'
NAME = 'espeak'
__author__ = 'seanfitz', 'jdorleans'
class ESpeak(TTS):
def __init__(self, lang, voice):
super(ESpeak, self).__init__(lang, voice)
super(ESpeak, self).__init__(lang, voice, ESpeakValidator(self))
def execute(self, sentence, client):
subprocess.call(
@ -35,20 +33,19 @@ class ESpeak(TTS):
class ESpeakValidator(TTSValidator):
def __init__(self):
super(ESpeakValidator, self).__init__()
def __init__(self, tts):
super(ESpeakValidator, self).__init__(tts)
def validate_lang(self, lang):
def validate_lang(self):
# TODO
pass
def validate_connection(self, tts):
def validate_connection(self):
try:
subprocess.call(['espeak', '--version'])
except:
raise Exception(
'ESpeak is not installed. Run on terminal: sudo apt-get '
'install espeak')
'ESpeak is not installed. Run: sudo apt-get install espeak')
def get_instance(self):
def get_tts_class(self):
return ESpeak

View File

@ -16,8 +16,6 @@
# along with Mycroft Core. If not, see <http://www.gnu.org/licenses/>.
import json
import requests
from mycroft.tts import TTSValidator
@ -25,8 +23,6 @@ from mycroft.tts.remote_tts import RemoteTTS
__author__ = 'jdorleans'
NAME = 'fatts'
class FATTS(RemoteTTS):
PARAMS = {
@ -39,7 +35,8 @@ class FATTS(RemoteTTS):
}
def __init__(self, lang, voice, url):
super(FATTS, self).__init__(lang, voice, url, '/say')
super(FATTS, self).__init__(lang, voice, url, '/say',
FATTSValidator(self))
def build_request_params(self, sentence):
params = self.PARAMS.copy()
@ -50,23 +47,23 @@ class FATTS(RemoteTTS):
class FATTSValidator(TTSValidator):
def __init__(self):
super(FATTSValidator, self).__init__()
def __init__(self, tts):
super(FATTSValidator, self).__init__(tts)
def validate_lang(self, lang):
def validate_lang(self):
# TODO
pass
def validate_connection(self, tts):
def validate_connection(self):
try:
resp = requests.get(tts.url + "/info/version", verify=False)
content = json.loads(resp.content)
if content['product'].find('FA-TTS') < 0:
resp = requests.get(self.tts.url + "/info/version", verify=False)
content = resp.json()
if content.get('product', '').find('FA-TTS') < 0:
raise Exception('Invalid FA-TTS server.')
except:
raise Exception(
'FA-TTS server could not be verified. Check your connection '
'to the server: ' + tts.url)
'to the server: ' + self.tts.url)
def get_instance(self):
def get_tts_class(self):
return FATTS

View File

@ -23,12 +23,10 @@ from mycroft.util import play_wav
__author__ = 'jdorleans'
NAME = 'gtts'
class GoogleTTS(TTS):
def __init__(self, lang, voice):
super(GoogleTTS, self).__init__(lang, voice)
super(GoogleTTS, self).__init__(lang, voice, GoogleTTSValidator(self))
def execute(self, sentence, client):
tts = gTTS(text=sentence, lang=self.lang)
@ -37,20 +35,20 @@ class GoogleTTS(TTS):
class GoogleTTSValidator(TTSValidator):
def __init__(self):
super(GoogleTTSValidator, self).__init__()
def __init__(self, tts):
super(GoogleTTSValidator, self).__init__(tts)
def validate_lang(self, lang):
def validate_lang(self):
# TODO
pass
def validate_connection(self, tts):
def validate_connection(self):
try:
gTTS(text='Hi').save(tts.filename)
gTTS(text='Hi').save(self.tts.filename)
except:
raise Exception(
'GoogleTTS server could not be verified. Please check your '
'internet connection.')
def get_instance(self):
def get_tts_class(self):
return GoogleTTS

View File

@ -23,8 +23,6 @@ from mycroft.tts.remote_tts import RemoteTTS
__author__ = 'jdorleans'
NAME = 'marytts'
class MaryTTS(RemoteTTS):
PARAMS = {
@ -37,7 +35,8 @@ class MaryTTS(RemoteTTS):
}
def __init__(self, lang, voice, url):
super(MaryTTS, self).__init__(lang, voice, url, '/process')
super(MaryTTS, self).__init__(lang, voice, url, '/process',
MaryTTSValidator(self))
def build_request_params(self, sentence):
params = self.PARAMS.copy()
@ -48,22 +47,22 @@ class MaryTTS(RemoteTTS):
class MaryTTSValidator(TTSValidator):
def __init__(self):
super(MaryTTSValidator, self).__init__()
def __init__(self, tts):
super(MaryTTSValidator, self).__init__(tts)
def validate_lang(self, lang):
def validate_lang(self):
# TODO
pass
def validate_connection(self, tts):
def validate_connection(self):
try:
resp = requests.get(tts.url + "/version", verify=False)
resp = requests.get(self.tts.url + "/version", verify=False)
if resp.content.find('Mary TTS server') < 0:
raise Exception('Invalid MaryTTS server.')
except:
raise Exception(
'MaryTTS server could not be verified. Check your connection '
'to the server: ' + tts.url)
'to the server: ' + self.tts.url)
def get_instance(self):
def get_tts_class(self):
return MaryTTS

View File

@ -16,11 +16,10 @@
# along with Mycroft Core. If not, see <http://www.gnu.org/licenses/>.
import subprocess
from os.path import join
import re
import random
import os
import time
from os.path import join
from mycroft import MYCROFT_ROOT_PATH
from mycroft.tts import TTS, TTSValidator
@ -31,9 +30,7 @@ __author__ = 'jdorleans'
config = ConfigurationManager.get().get("tts", {})
NAME = 'mimic'
BIN = config.get(
"mimic.path", join(MYCROFT_ROOT_PATH, 'mimic', 'bin', 'mimic'))
BIN = config.get("path", join(MYCROFT_ROOT_PATH, 'mimic', 'bin', 'mimic'))
# Mapping based on Jeffers phoneme to viseme map, seen in table 1 from:
# http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.221.6377&rep=rep1&type=pdf
@ -49,7 +46,7 @@ BIN = config.get(
class Mimic(TTS):
def __init__(self, lang, voice):
super(Mimic, self).__init__(lang, voice)
super(Mimic, self).__init__(lang, voice, MimicValidator(self))
self.args = ['-voice', self.voice]
stretch = config.get('duration_stretch', None)
if stretch:
@ -151,19 +148,19 @@ class Mimic(TTS):
class MimicValidator(TTSValidator):
def __init__(self):
super(MimicValidator, self).__init__()
def __init__(self, tts):
super(MimicValidator, self).__init__(tts)
def validate_lang(self, lang):
def validate_lang(self):
# TODO
pass
def validate_connection(self, tts):
def validate_connection(self):
try:
subprocess.call([BIN, '--version'])
except:
raise Exception(
'Mimic is not installed. Make sure install-mimic.sh ran '
'properly.')
'Mimic is not installed. Run install-mimic.sh to install it.')
def get_instance(self):
def get_tts_class(self):
return Mimic

View File

@ -18,7 +18,6 @@
import abc
import re
from requests_futures.sessions import FuturesSession
from mycroft.tts import TTS
@ -38,8 +37,8 @@ class RemoteTTS(TTS):
whole sentence into small ones.
"""
def __init__(self, lang, voice, url, api_path):
super(RemoteTTS, self).__init__(lang, voice)
def __init__(self, lang, voice, url, api_path, validator):
super(RemoteTTS, self).__init__(lang, voice, validator)
self.api_path = api_path
self.url = remove_last_slash(url)
self.session = FuturesSession()

View File

@ -22,12 +22,10 @@ from mycroft.tts import TTS, TTSValidator
__author__ = 'jdorleans'
NAME = 'spdsay'
class SpdSay(TTS):
def __init__(self, lang, voice):
super(SpdSay, self).__init__(lang, voice)
super(SpdSay, self).__init__(lang, voice, SpdSayValidator(self))
def execute(self, sentence, client):
subprocess.call(
@ -35,20 +33,20 @@ class SpdSay(TTS):
class SpdSayValidator(TTSValidator):
def __init__(self):
super(SpdSayValidator, self).__init__()
def __init__(self, tts):
super(SpdSayValidator, self).__init__(tts)
def validate_lang(self, lang):
def validate_lang(self):
# TODO
pass
def validate_connection(self, tts):
def validate_connection(self):
try:
subprocess.call(['spd-say', '--version'])
except:
raise Exception(
'SpdSay is not installed. Run on terminal: sudo apt-get'
'install speech-dispatcher')
'SpdSay is not installed. Run: sudo apt-get install '
'speech-dispatcher')
def get_instance(self):
def get_tts_class(self):
return SpdSay