mirror of https://github.com/coqui-ai/TTS.git
added missing phonemes, synthesizer.py now setup the correct input layer
parent
0a3dba4279
commit
95de2cd559
|
@ -1,16 +1,13 @@
|
|||
import io
|
||||
import os
|
||||
import librosa
|
||||
import torch
|
||||
import scipy
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from utils.text import text_to_sequence
|
||||
from utils.generic_utils import load_config
|
||||
from utils.audio import AudioProcessor
|
||||
from models.tacotron import Tacotron
|
||||
from matplotlib import pylab as plt
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from models.tacotron import Tacotron
|
||||
from utils.audio import AudioProcessor
|
||||
from utils.generic_utils import load_config
|
||||
from utils.text import phoneme_to_sequence, phonemes, symbols, text_to_sequence
|
||||
|
||||
class Synthesizer(object):
|
||||
def load_model(self, model_path, model_name, model_config, use_cuda):
|
||||
|
@ -22,14 +19,22 @@ class Synthesizer(object):
|
|||
config = load_config(model_config)
|
||||
self.config = config
|
||||
self.use_cuda = use_cuda
|
||||
self.use_phonemes = config.use_phonemes
|
||||
self.ap = AudioProcessor(**config.audio)
|
||||
self.model = Tacotron(config.embedding_size, self.ap.num_freq, self.ap.num_mels, config.r)
|
||||
|
||||
if self.use_phonemes:
|
||||
self.input_size = len(phonemes)
|
||||
self.input_adapter = lambda sen: phoneme_to_sequence(sen, [self.config.text_cleaner], self.config.phoneme_language)
|
||||
else:
|
||||
self.input_size = len(symbols)
|
||||
self.input_adapter = lambda sen: text_to_sequence(sen, [self.config.text_cleaner])
|
||||
|
||||
self.model = Tacotron(self.input_size, config.embedding_size, self.ap.num_freq, self.ap.num_mels, config.r)
|
||||
# load model state
|
||||
if use_cuda:
|
||||
cp = torch.load(self.model_file)
|
||||
else:
|
||||
cp = torch.load(
|
||||
self.model_file, map_location=lambda storage, loc: storage)
|
||||
cp = torch.load(self.model_file, map_location=lambda storage, loc: storage)
|
||||
# load the model
|
||||
self.model.load_state_dict(cp['model'])
|
||||
if use_cuda:
|
||||
|
@ -42,7 +47,6 @@ class Synthesizer(object):
|
|||
self.ap.save_wav(wav, path)
|
||||
|
||||
def tts(self, text):
|
||||
text_cleaner = [self.config.text_cleaner]
|
||||
wavs = []
|
||||
for sen in text.split('.'):
|
||||
if len(sen) < 3:
|
||||
|
@ -51,7 +55,9 @@ class Synthesizer(object):
|
|||
sen += '.'
|
||||
print(sen)
|
||||
sen = sen.strip()
|
||||
seq = np.array(text_to_sequence(sen, text_cleaner))
|
||||
|
||||
seq = np.array(self.input_adapter(sen))
|
||||
|
||||
chars_var = torch.from_numpy(seq).unsqueeze(0).long()
|
||||
if self.use_cuda:
|
||||
chars_var = chars_var.cuda()
|
||||
|
@ -59,8 +65,9 @@ class Synthesizer(object):
|
|||
chars_var)
|
||||
linear_out = linear_out[0].data.cpu().numpy()
|
||||
wav = self.ap.inv_spectrogram(linear_out.T)
|
||||
out = io.BytesIO()
|
||||
wavs += list(wav)
|
||||
wavs += [0] * 10000
|
||||
|
||||
out = io.BytesIO()
|
||||
self.save_wav(wavs, out)
|
||||
return out
|
||||
return out
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
import unittest
|
||||
|
||||
from utils.text import phonemes
|
||||
|
||||
class SymbolsTest(unittest.TestCase):
|
||||
def test_uniqueness(self):
|
||||
assert sorted(phonemes) == sorted(list(set(phonemes)))
|
|
@ -5,7 +5,6 @@ Defines the set of symbols used in text input to the model.
|
|||
The default is a set of ASCII characters that works well for English or text that has been run
|
||||
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
|
||||
'''
|
||||
from utils.text import cmudict
|
||||
|
||||
_pad = '_'
|
||||
_eos = '~'
|
||||
|
@ -13,22 +12,24 @@ _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
|
|||
_punctuations = '!\'(),-.:;? '
|
||||
_phoneme_punctuations = '.!;:,?'
|
||||
|
||||
# TODO: include more phoneme characters for other languages.
|
||||
_phonemes = ['l','ɹ','ɜ','ɚ','k','u','ʔ','ð','ɐ','ɾ','ɑ','ɔ','b','ɛ','t','v','n','m','ʊ','ŋ','s',
|
||||
'ʌ','o','ʃ','i','p','æ','e','a','ʒ',' ','h','ɪ','ɡ','f','r','w','ɫ','ɬ','d','x','ː',
|
||||
'ᵻ','ə','j','θ','z','ɒ']
|
||||
|
||||
_phonemes = sorted(list(set(_phonemes)))
|
||||
# Phonemes definition
|
||||
_vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ'
|
||||
_non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ'
|
||||
_pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ'
|
||||
_suprasegmentals = 'ˈˌːˑ'
|
||||
_other_symbols = 'ʍwɥʜʢʡɕʑɺɧ'
|
||||
_diacrilics = 'ɚ˞ɫ'
|
||||
_phonemes = sorted(list(_vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics))
|
||||
|
||||
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
|
||||
_arpabet = ['@' + s for s in _phonemes]
|
||||
|
||||
# Export all symbols:
|
||||
symbols = [_pad, _eos] + list(_characters) + _arpabet
|
||||
phonemes = [_pad, _eos] + list(_phonemes) + list(_punctuations)
|
||||
phonemes = [_pad, _eos] + _phonemes + list(_punctuations)
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(" > TTS symbols ")
|
||||
print(" > TTS symbols {}".format(len(symbols)))
|
||||
print(symbols)
|
||||
print(" > TTS phonemes ")
|
||||
print(" > TTS phonemes {}".format(len(phonemes)))
|
||||
print(phonemes)
|
||||
|
|
Loading…
Reference in New Issue