mirror of https://github.com/coqui-ai/TTS.git
Use step wise LR scheduler + adapt train.py for passing squence mask directly
parent
e0bce1d2d1
commit
96e2e3c776
88
train.py
88
train.py
|
@ -16,11 +16,12 @@ from tensorboardX import SummaryWriter
|
|||
from utils.generic_utils import (
|
||||
synthesis, remove_experiment_folder, create_experiment_folder,
|
||||
save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
|
||||
check_update, get_commit_hash)
|
||||
check_update, get_commit_hash, sequence_mask)
|
||||
from utils.visual import plot_alignment, plot_spectrogram
|
||||
from models.tacotron import Tacotron
|
||||
from layers.losses import L1LossMasked
|
||||
from utils.audio import AudioProcessor
|
||||
from torch.optim.lr_scheduler import StepLR
|
||||
|
||||
torch.manual_seed(1)
|
||||
torch.set_num_threads(4)
|
||||
|
@ -28,7 +29,7 @@ use_cuda = torch.cuda.is_available()
|
|||
|
||||
|
||||
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||
ap, epoch):
|
||||
scheduler, ap, epoch):
|
||||
model = model.train()
|
||||
epoch_time = 0
|
||||
avg_linear_loss = 0
|
||||
|
@ -58,15 +59,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
epoch * len(data_loader) + 1
|
||||
|
||||
# setup lr
|
||||
current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
|
||||
current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)
|
||||
|
||||
for params_group in optimizer.param_groups:
|
||||
params_group['lr'] = current_lr
|
||||
|
||||
for params_group in optimizer_st.param_groups:
|
||||
params_group['lr'] = current_lr_st
|
||||
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
optimizer_st.zero_grad()
|
||||
|
||||
|
@ -79,9 +72,12 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
linear_input = linear_input.cuda()
|
||||
stop_targets = stop_targets.cuda()
|
||||
|
||||
# compute mask for padding
|
||||
mask = sequence_mask(text_lengths)
|
||||
|
||||
# forward pass
|
||||
mel_output, linear_output, alignments, stop_tokens =\
|
||||
model.forward(text_input, mel_input, text_lengths)
|
||||
mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
|
||||
model, (text_input, mel_input, mask))
|
||||
|
||||
# loss computation
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
|
@ -94,7 +90,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
|
||||
# backpass and check the grad norm for spec losses
|
||||
loss.backward(retain_graph=True)
|
||||
grad_norm, skip_flag = check_update(model, 0.5, 100)
|
||||
grad_norm, skip_flag = check_update(model, 1)
|
||||
if skip_flag:
|
||||
optimizer.zero_grad()
|
||||
print(" | > Iteration skipped!!", flush=True)
|
||||
|
@ -103,8 +99,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
|
||||
# backpass and check the grad norm for stop loss
|
||||
stop_loss.backward()
|
||||
grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet,
|
||||
0.5, 100)
|
||||
grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
|
||||
if skip_flag:
|
||||
optimizer_st.zero_grad()
|
||||
print(" | | > Iteration skipped fro stopnet!!")
|
||||
|
@ -115,18 +110,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
epoch_time += step_time
|
||||
|
||||
if current_step % c.print_step == 0:
|
||||
print(" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
|
||||
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
||||
"GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter,
|
||||
batch_n_iter,
|
||||
current_step,
|
||||
loss.item(),
|
||||
linear_loss.item(),
|
||||
mel_loss.item(),
|
||||
stop_loss.item(),
|
||||
grad_norm.item(),
|
||||
grad_norm_st.item(),
|
||||
step_time), flush=True)
|
||||
print(
|
||||
" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
|
||||
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
||||
"GradNormST:{:.5f} StepTime:{:.2f}".format(
|
||||
num_iter, batch_n_iter, current_step, loss.item(),
|
||||
linear_loss.item(), mel_loss.item(), stop_loss.item(),
|
||||
grad_norm.item(), grad_norm_st.item(), step_time),
|
||||
flush=True)
|
||||
|
||||
avg_linear_loss += linear_loss.item()
|
||||
avg_mel_loss += mel_loss.item()
|
||||
|
@ -184,16 +175,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
avg_step_time /= (num_iter + 1)
|
||||
|
||||
# print epoch stats
|
||||
print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
||||
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
|
||||
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
||||
"AvgStepTime:{:.2f}".format(current_step,
|
||||
avg_total_loss,
|
||||
avg_linear_loss,
|
||||
avg_mel_loss,
|
||||
avg_stop_loss,
|
||||
epoch_time,
|
||||
avg_step_time), flush=True)
|
||||
print(
|
||||
" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
||||
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
|
||||
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
||||
"AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
|
||||
avg_linear_loss, avg_mel_loss,
|
||||
avg_stop_loss, epoch_time, avg_step_time),
|
||||
flush=True)
|
||||
|
||||
# Plot Training Epoch Stats
|
||||
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
|
||||
|
@ -266,8 +255,10 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
|||
if num_iter % c.print_step == 0:
|
||||
print(
|
||||
" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
|
||||
"StopLoss: {:.5f} ".format(loss.item(), linear_loss.item(),
|
||||
mel_loss.item(), stop_loss.item()),
|
||||
"StopLoss: {:.5f} ".format(loss.item(),
|
||||
linear_loss.item(),
|
||||
mel_loss.item(),
|
||||
stop_loss.item()),
|
||||
flush=True)
|
||||
|
||||
avg_linear_loss += linear_loss.item()
|
||||
|
@ -322,11 +313,11 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
|||
ap.griffin_lim_iters = 60
|
||||
for idx, test_sentence in enumerate(test_sentences):
|
||||
try:
|
||||
wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
|
||||
use_cuda, c.text_cleaner)
|
||||
wav_name = 'TestSentences/{}'.format(idx)
|
||||
tb.add_audio(
|
||||
wav_name, wav, current_step, sample_rate=c.sample_rate)
|
||||
wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
|
||||
use_cuda, c.text_cleaner)
|
||||
wav_name = 'TestSentences/{}'.format(idx)
|
||||
tb.add_audio(
|
||||
wav_name, wav, current_step, sample_rate=c.sample_rate)
|
||||
except:
|
||||
print(" !! Error as creating Test Sentence -", idx)
|
||||
pass
|
||||
|
@ -405,7 +396,7 @@ def main(args):
|
|||
checkpoint = torch.load(args.restore_path)
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
if use_cuda:
|
||||
model = nn.DataParallel(model.cuda())
|
||||
model = model.cuda()
|
||||
criterion.cuda()
|
||||
criterion_st.cuda()
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
|
@ -423,10 +414,11 @@ def main(args):
|
|||
args.restore_step = 0
|
||||
print("\n > Starting a new training", flush=True)
|
||||
if use_cuda:
|
||||
model = nn.DataParallel(model.cuda())
|
||||
model = model.cuda()
|
||||
criterion.cuda()
|
||||
criterion_st.cuda()
|
||||
|
||||
scheduler = StepLR(optimizer, step_size=c.decay_step, gamma=c.lr_decay)
|
||||
num_params = count_parameters(model)
|
||||
print(" | > Model has {} parameters".format(num_params), flush=True)
|
||||
|
||||
|
@ -439,7 +431,7 @@ def main(args):
|
|||
for epoch in range(0, c.epochs):
|
||||
train_loss, current_step = train(model, criterion, criterion_st,
|
||||
train_loader, optimizer, optimizer_st,
|
||||
ap, epoch)
|
||||
scheduler, ap, epoch)
|
||||
val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
|
||||
current_step)
|
||||
print(
|
||||
|
|
Loading…
Reference in New Issue