mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'dev' of https://github.com/mozilla/TTS into dev
commit
21364331d2
|
@ -10,7 +10,7 @@ import time
|
|||
|
||||
import torch
|
||||
|
||||
from TTS.tts.utils.generic_utils import setup_model
|
||||
from TTS.tts.utils.generic_utils import setup_model, is_tacotron
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
@ -125,7 +125,8 @@ if __name__ == "__main__":
|
|||
model.eval()
|
||||
if args.use_cuda:
|
||||
model.cuda()
|
||||
model.decoder.set_r(cp['r'])
|
||||
if is_tacotron(C):
|
||||
model.decoder.set_r(cp['r'])
|
||||
|
||||
# load vocoder model
|
||||
if args.vocoder_path != "":
|
||||
|
@ -153,7 +154,10 @@ if __name__ == "__main__":
|
|||
args.speaker_fileid = None
|
||||
|
||||
if args.gst_style is None:
|
||||
gst_style = C.gst['gst_style_input']
|
||||
if is_tacotron(C):
|
||||
gst_style = C.gst['gst_style_input']
|
||||
else:
|
||||
gst_style = None
|
||||
else:
|
||||
# check if gst_style string is a dict, if is dict convert else use string
|
||||
try:
|
||||
|
|
|
@ -28,7 +28,6 @@ def split_dataset(items):
|
|||
return items_eval, items
|
||||
return items[:eval_split_size], items[eval_split_size:]
|
||||
|
||||
|
||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||
def sequence_mask(sequence_length, max_len=None):
|
||||
if max_len is None:
|
||||
|
@ -50,7 +49,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.model))
|
||||
if c.model.lower() in "tacotron":
|
||||
model = MyModel(num_chars=num_chars,
|
||||
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
|
||||
|
@ -77,7 +76,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
ddc_r=c.ddc_r,
|
||||
speaker_embedding_dim=speaker_embedding_dim)
|
||||
elif c.model.lower() == "tacotron2":
|
||||
model = MyModel(num_chars=num_chars,
|
||||
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
postnet_output_dim=c.audio['num_mels'],
|
||||
|
@ -103,7 +102,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
ddc_r=c.ddc_r,
|
||||
speaker_embedding_dim=speaker_embedding_dim)
|
||||
elif c.model.lower() == "glow_tts":
|
||||
model = MyModel(num_chars=num_chars,
|
||||
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
hidden_channels=192,
|
||||
filter_channels=768,
|
||||
filter_channels_dp=256,
|
||||
|
@ -131,7 +130,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
return model
|
||||
|
||||
def is_tacotron(c):
|
||||
return False if c['model'] == 'glow_tts' else True
|
||||
return False if 'glow_tts' in c['model'] else True
|
||||
|
||||
def check_config_tts(c):
|
||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str)
|
||||
|
|
Loading…
Reference in New Issue