small fixes

pull/10/head
Eren Golge 2019-11-19 16:48:04 +01:00
parent 448a70823d
commit a31d1cb2d9
2 changed files with 12 additions and 14 deletions

View File

@@ -43,7 +43,7 @@
// VALIDATION // VALIDATION
"run_eval": true, "run_eval": true,
"test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time. "test_delay_epochs": 10, //Until attention is aligned, testing only wastes computation time.
"test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences. "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
// OPTIMIZER // OPTIMIZER

View File

@@ -8,7 +8,6 @@ import traceback
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.datasets.TTSDataset import MyDataset from TTS.datasets.TTSDataset import MyDataset
@@ -171,7 +170,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
stop_targets) if c.stopnet else torch.zeros(1) stop_targets) if c.stopnet else torch.zeros(1)
if c.loss_masking: if c.loss_masking:
decoder_loss = criterion(decoder_output, mel_input, mel_lengths) decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
if c.model in ["Tacotron"]: if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input, postnet_loss = criterion(postnet_output, linear_input,
mel_lengths) mel_lengths)
else: else:
@@ -179,7 +178,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
mel_lengths) mel_lengths)
else: else:
decoder_loss = criterion(decoder_output, mel_input) decoder_loss = criterion(decoder_output, mel_input)
if c.model in ["Tacotron"]: if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input) postnet_loss = criterion(postnet_output, linear_input)
else: else:
postnet_loss = criterion(postnet_output, mel_input) postnet_loss = criterion(postnet_output, mel_input)
@@ -277,7 +276,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
# Diagnostic visualizations # Diagnostic visualizations
const_spec = postnet_output[0].data.cpu().numpy() const_spec = postnet_output[0].data.cpu().numpy()
gt_spec = linear_input[0].data.cpu().numpy() if c.model in [ gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
"Tacotron" "Tacotron", "TacotronGST"
] else mel_input[0].data.cpu().numpy() ] else mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy() align_img = alignments[0].data.cpu().numpy()
@@ -293,7 +292,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
tb_logger.tb_train_figures(global_step, figures) tb_logger.tb_train_figures(global_step, figures)
# Sample audio # Sample audio
if c.model in ["Tacotron"]: if c.model in ["Tacotron", "TacotronGST"]:
train_audio = ap.inv_spectrogram(const_spec.T) train_audio = ap.inv_spectrogram(const_spec.T)
else: else:
train_audio = ap.inv_mel_spectrogram(const_spec.T) train_audio = ap.inv_mel_spectrogram(const_spec.T)
@@ -370,7 +369,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
if c.loss_masking: if c.loss_masking:
decoder_loss = criterion(decoder_output, mel_input, decoder_loss = criterion(decoder_output, mel_input,
mel_lengths) mel_lengths)
if c.model in ["Tacotron"]: if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input, postnet_loss = criterion(postnet_output, linear_input,
mel_lengths) mel_lengths)
else: else:
@@ -378,7 +377,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
mel_lengths) mel_lengths)
else: else:
decoder_loss = criterion(decoder_output, mel_input) decoder_loss = criterion(decoder_output, mel_input)
if c.model in ["Tacotron"]: if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input) postnet_loss = criterion(postnet_output, linear_input)
else: else:
postnet_loss = criterion(postnet_output, mel_input) postnet_loss = criterion(postnet_output, mel_input)
@@ -434,7 +433,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
idx = np.random.randint(mel_input.shape[0]) idx = np.random.randint(mel_input.shape[0])
const_spec = postnet_output[idx].data.cpu().numpy() const_spec = postnet_output[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [ gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
"Tacotron" "Tacotron", "TacotronGST"
] else mel_input[idx].data.cpu().numpy() ] else mel_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy() align_img = alignments[idx].data.cpu().numpy()
@@ -445,7 +444,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
} }
# Sample audio # Sample audio
if c.model in ["Tacotron"]: if c.model in ["Tacotron", "TacotronGST"]:
eval_audio = ap.inv_spectrogram(const_spec.T) eval_audio = ap.inv_spectrogram(const_spec.T)
else: else:
eval_audio = ap.inv_mel_spectrogram(const_spec.T) eval_audio = ap.inv_mel_spectrogram(const_spec.T)
@@ -466,7 +465,6 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
tb_logger.tb_eval_stats(global_step, epoch_stats) tb_logger.tb_eval_stats(global_step, epoch_stats)
tb_logger.tb_eval_figures(global_step, eval_figures) tb_logger.tb_eval_figures(global_step, eval_figures)
if args.rank == 0 and epoch > c.test_delay_epochs: if args.rank == 0 and epoch > c.test_delay_epochs:
if c.test_sentences_file is None: if c.test_sentences_file is None:
test_sentences = [ test_sentences = [
@@ -562,10 +560,10 @@ def main(args): # pylint: disable=redefined-outer-name
optimizer_st = None optimizer_st = None
if c.loss_masking: if c.loss_masking:
criterion = L1LossMasked() if c.model in ["Tacotron" criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
] else MSELossMasked() ] else MSELossMasked()
else: else:
criterion = nn.L1Loss() if c.model in ["Tacotron" criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
] else nn.MSELoss() ] else nn.MSELoss()
criterion_st = nn.BCEWithLogitsLoss( criterion_st = nn.BCEWithLogitsLoss(
pos_weight=torch.tensor(10)) if c.stopnet else None pos_weight=torch.tensor(10)) if c.stopnet else None
@@ -686,7 +684,7 @@ if __name__ == '__main__':
args.restore_path = latest_model_file args.restore_path = latest_model_file
print(f" > Training continues for {args.restore_path}") print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs # setup output paths and read configs
c = load_config(args.config_path) c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__)) _ = os.path.dirname(os.path.realpath(__file__))