mirror of https://github.com/coqui-ai/TTS.git
small fixes
parent
448a70823d
commit
a31d1cb2d9
|
@ -43,7 +43,7 @@
|
|||
|
||||
// VALIDATION
|
||||
"run_eval": true,
|
||||
"test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time.
|
||||
"test_delay_epochs": 10, //Until attention is aligned, testing only wastes computation time.
|
||||
"test_sentences_file": null, // Path to a file of sentences to be used for testing. If it is null, the default English sentences are used.
|
||||
|
||||
// OPTIMIZER
|
||||
|
|
24
train.py
24
train.py
|
@ -8,7 +8,6 @@ import traceback
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from TTS.datasets.TTSDataset import MyDataset
|
||||
|
@ -171,7 +170,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
stop_targets) if c.stopnet else torch.zeros(1)
|
||||
if c.loss_masking:
|
||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
||||
if c.model in ["Tacotron"]:
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
postnet_loss = criterion(postnet_output, linear_input,
|
||||
mel_lengths)
|
||||
else:
|
||||
|
@ -179,7 +178,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
mel_lengths)
|
||||
else:
|
||||
decoder_loss = criterion(decoder_output, mel_input)
|
||||
if c.model in ["Tacotron"]:
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
postnet_loss = criterion(postnet_output, linear_input)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input)
|
||||
|
@ -277,7 +276,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
# Diagnostic visualizations
|
||||
const_spec = postnet_output[0].data.cpu().numpy()
|
||||
gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
|
||||
"Tacotron"
|
||||
"Tacotron", "TacotronGST"
|
||||
] else mel_input[0].data.cpu().numpy()
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
|
||||
|
@ -293,7 +292,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
tb_logger.tb_train_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
if c.model in ["Tacotron"]:
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
train_audio = ap.inv_spectrogram(const_spec.T)
|
||||
else:
|
||||
train_audio = ap.inv_mel_spectrogram(const_spec.T)
|
||||
|
@ -370,7 +369,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
if c.loss_masking:
|
||||
decoder_loss = criterion(decoder_output, mel_input,
|
||||
mel_lengths)
|
||||
if c.model in ["Tacotron"]:
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
postnet_loss = criterion(postnet_output, linear_input,
|
||||
mel_lengths)
|
||||
else:
|
||||
|
@ -378,7 +377,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
mel_lengths)
|
||||
else:
|
||||
decoder_loss = criterion(decoder_output, mel_input)
|
||||
if c.model in ["Tacotron"]:
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
postnet_loss = criterion(postnet_output, linear_input)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input)
|
||||
|
@ -434,7 +433,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
idx = np.random.randint(mel_input.shape[0])
|
||||
const_spec = postnet_output[idx].data.cpu().numpy()
|
||||
gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
|
||||
"Tacotron"
|
||||
"Tacotron", "TacotronGST"
|
||||
] else mel_input[idx].data.cpu().numpy()
|
||||
align_img = alignments[idx].data.cpu().numpy()
|
||||
|
||||
|
@ -445,7 +444,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
}
|
||||
|
||||
# Sample audio
|
||||
if c.model in ["Tacotron"]:
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
eval_audio = ap.inv_spectrogram(const_spec.T)
|
||||
else:
|
||||
eval_audio = ap.inv_mel_spectrogram(const_spec.T)
|
||||
|
@ -466,7 +465,6 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
tb_logger.tb_eval_stats(global_step, epoch_stats)
|
||||
tb_logger.tb_eval_figures(global_step, eval_figures)
|
||||
|
||||
|
||||
if args.rank == 0 and epoch > c.test_delay_epochs:
|
||||
if c.test_sentences_file is None:
|
||||
test_sentences = [
|
||||
|
@ -562,10 +560,10 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
optimizer_st = None
|
||||
|
||||
if c.loss_masking:
|
||||
criterion = L1LossMasked() if c.model in ["Tacotron"
|
||||
criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
|
||||
] else MSELossMasked()
|
||||
else:
|
||||
criterion = nn.L1Loss() if c.model in ["Tacotron"
|
||||
criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
|
||||
] else nn.MSELoss()
|
||||
criterion_st = nn.BCEWithLogitsLoss(
|
||||
pos_weight=torch.tensor(10)) if c.stopnet else None
|
||||
|
@ -686,7 +684,7 @@ if __name__ == '__main__':
|
|||
args.restore_path = latest_model_file
|
||||
print(f" > Training continues for {args.restore_path}")
|
||||
|
||||
# setup output paths and read configs
|
||||
# set up output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
|
Loading…
Reference in New Issue