Merge branch 'dev' of https://github.com/mozilla/TTS into dev

pull/10/head
erogol 2020-11-09 13:31:12 +01:00
commit 21364331d2
2 changed files with 11 additions and 8 deletions

View File

@ -10,7 +10,7 @@ import time
import torch
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.generic_utils import setup_model, is_tacotron
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
@ -125,7 +125,8 @@ if __name__ == "__main__":
model.eval()
if args.use_cuda:
model.cuda()
model.decoder.set_r(cp['r'])
if is_tacotron(C):
model.decoder.set_r(cp['r'])
# load vocoder model
if args.vocoder_path != "":
@ -153,7 +154,10 @@ if __name__ == "__main__":
args.speaker_fileid = None
if args.gst_style is None:
gst_style = C.gst['gst_style_input']
if is_tacotron(C):
gst_style = C.gst['gst_style_input']
else:
gst_style = None
else:
# check if gst_style string is a dict, if is dict convert else use string
try:

View File

@ -28,7 +28,6 @@ def split_dataset(items):
return items_eval, items
return items[:eval_split_size], items[eval_split_size:]
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
if max_len is None:
@ -50,7 +49,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
MyModel = getattr(MyModel, to_camel(c.model))
if c.model.lower() in "tacotron":
model = MyModel(num_chars=num_chars,
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
@ -77,7 +76,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
ddc_r=c.ddc_r,
speaker_embedding_dim=speaker_embedding_dim)
elif c.model.lower() == "tacotron2":
model = MyModel(num_chars=num_chars,
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio['num_mels'],
@ -103,7 +102,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
ddc_r=c.ddc_r,
speaker_embedding_dim=speaker_embedding_dim)
elif c.model.lower() == "glow_tts":
model = MyModel(num_chars=num_chars,
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
hidden_channels=192,
filter_channels=768,
filter_channels_dp=256,
@ -131,7 +130,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
return model
def is_tacotron(c):
return False if c['model'] == 'glow_tts' else True
return False if 'glow_tts' in c['model'] else True
def check_config_tts(c):
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str)