small fixes

pull/10/head
Eren Golge 2019-11-19 16:48:04 +01:00
parent 448a70823d
commit a31d1cb2d9
2 changed files with 12 additions and 14 deletions

View File

@ -43,7 +43,7 @@
// VALIDATION
"run_eval": true,
"test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time.
"test_delay_epochs": 10, //Until attention is aligned, testing only wastes computation time.
"test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
// OPTIMIZER

View File

@ -8,7 +8,6 @@ import traceback
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from TTS.datasets.TTSDataset import MyDataset
@ -171,7 +170,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
stop_targets) if c.stopnet else torch.zeros(1)
if c.loss_masking:
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
if c.model in ["Tacotron"]:
if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input,
mel_lengths)
else:
@ -179,7 +178,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
mel_lengths)
else:
decoder_loss = criterion(decoder_output, mel_input)
if c.model in ["Tacotron"]:
if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input)
else:
postnet_loss = criterion(postnet_output, mel_input)
@ -277,7 +276,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
# Diagnostic visualizations
const_spec = postnet_output[0].data.cpu().numpy()
gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
"Tacotron"
"Tacotron", "TacotronGST"
] else mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy()
@ -293,7 +292,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
tb_logger.tb_train_figures(global_step, figures)
# Sample audio
if c.model in ["Tacotron"]:
if c.model in ["Tacotron", "TacotronGST"]:
train_audio = ap.inv_spectrogram(const_spec.T)
else:
train_audio = ap.inv_mel_spectrogram(const_spec.T)
@ -370,7 +369,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
if c.loss_masking:
decoder_loss = criterion(decoder_output, mel_input,
mel_lengths)
if c.model in ["Tacotron"]:
if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input,
mel_lengths)
else:
@ -378,7 +377,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
mel_lengths)
else:
decoder_loss = criterion(decoder_output, mel_input)
if c.model in ["Tacotron"]:
if c.model in ["Tacotron", "TacotronGST"]:
postnet_loss = criterion(postnet_output, linear_input)
else:
postnet_loss = criterion(postnet_output, mel_input)
@ -434,7 +433,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
idx = np.random.randint(mel_input.shape[0])
const_spec = postnet_output[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
"Tacotron"
"Tacotron", "TacotronGST"
] else mel_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy()
@ -445,7 +444,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
}
# Sample audio
if c.model in ["Tacotron"]:
if c.model in ["Tacotron", "TacotronGST"]:
eval_audio = ap.inv_spectrogram(const_spec.T)
else:
eval_audio = ap.inv_mel_spectrogram(const_spec.T)
@ -466,7 +465,6 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
tb_logger.tb_eval_stats(global_step, epoch_stats)
tb_logger.tb_eval_figures(global_step, eval_figures)
if args.rank == 0 and epoch > c.test_delay_epochs:
if c.test_sentences_file is None:
test_sentences = [
@ -562,10 +560,10 @@ def main(args): # pylint: disable=redefined-outer-name
optimizer_st = None
if c.loss_masking:
criterion = L1LossMasked() if c.model in ["Tacotron"
criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
] else MSELossMasked()
else:
criterion = nn.L1Loss() if c.model in ["Tacotron"
criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
] else nn.MSELoss()
criterion_st = nn.BCEWithLogitsLoss(
pos_weight=torch.tensor(10)) if c.stopnet else None
@ -686,7 +684,7 @@ if __name__ == '__main__':
args.restore_path = latest_model_file
print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))