mirror of https://github.com/coqui-ai/TTS.git
small clean
parent
a2a2065bb4
commit
405fbc434e
41
train.py
41
train.py
|
@ -31,8 +31,6 @@ from models.tacotron import Tacotron
|
|||
use_cuda = torch.cuda.is_available()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--restore_step', type=int,
|
||||
help='Global step to restore checkpoint', default=0)
|
||||
parser.add_argument('--restore_path', type=str,
|
||||
help='Folder path to checkpoints', default=0)
|
||||
parser.add_argument('--config_path', type=str,
|
||||
|
@ -67,6 +65,8 @@ def signal_handler(signal, frame):
|
|||
def train(model, criterion, data_loader, optimizer, epoch):
|
||||
model = model.train()
|
||||
epoch_time = 0
|
||||
avg_linear_loss = 0
|
||||
avg_mel_loss = 0
|
||||
|
||||
print(" | > Epoch {}/{}".format(epoch, c.epochs))
|
||||
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
||||
|
@ -180,11 +180,10 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
print(audio_signal.max())
|
||||
print(audio_signal.min())
|
||||
|
||||
avg_linear_loss = np.mean(
|
||||
progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1]))
|
||||
avg_mel_loss = np.mean(
|
||||
progbar.sum_values['mel_loss'][0] / max(1, progbar.sum_values['mel_loss'][1]))
|
||||
avg_total_loss = avg_mel_loss + avg_linear_loss
|
||||
|
||||
avg_linear_loss /= (num_iter + 1)
|
||||
avg_mel_loss /= (num_iter + 1)
|
||||
avg_total_loss = avg_mel_loss + avg_linear_loss
|
||||
|
||||
# Plot Training Epoch Stats
|
||||
tb.add_scalar('TrainEpochLoss/TotalLoss', loss.data[0], current_step)
|
||||
|
@ -203,6 +202,9 @@ def evaluate(model, criterion, data_loader, current_step):
|
|||
print("\n | > Validation")
|
||||
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
||||
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
||||
|
||||
avg_linear_loss = 0
|
||||
avg_mel_loss = 0
|
||||
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
|
@ -242,19 +244,22 @@ def evaluate(model, criterion, data_loader, current_step):
|
|||
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
|
||||
('linear_loss', linear_loss.data[0]),
|
||||
('mel_loss', mel_loss.data[0])])
|
||||
|
||||
avg_linear_loss += linear_loss.data[0]
|
||||
avg_mel_loss += avg_mel_loss.data[0]
|
||||
|
||||
# Diagnostic visualizations
|
||||
idx = np.random.randint(c.batch_size)
|
||||
const_spec = linear_output[idx].data.cpu().numpy()
|
||||
gt_spec = linear_spec_var[idx].data.cpu().numpy()
|
||||
align_img = alignments[idx].data.cpu().numpy()
|
||||
|
||||
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
|
||||
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
|
||||
align_img = plot_alignment(align_img)
|
||||
|
||||
tb.add_image('ValVisual/Reconstruction', const_spec, current_step)
|
||||
tb.add_image('ValVisual/GroundTruth', gt_spec, current_step)
|
||||
|
||||
align_img = alignments[idx].data.cpu().numpy()
|
||||
align_img = plot_alignment(align_img)
|
||||
tb.add_image('ValVisual/ValidationAlignment', align_img, current_step)
|
||||
|
||||
# Sample audio
|
||||
|
@ -270,10 +275,8 @@ def evaluate(model, criterion, data_loader, current_step):
|
|||
print(audio_signal.min())
|
||||
|
||||
# compute average losses
|
||||
avg_linear_loss = np.mean(
|
||||
progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1]))
|
||||
avg_mel_loss = np.mean(
|
||||
progbar.sum_values['mel_loss'][0] / max(1, progbar.sum_values['mel_loss'][1]))
|
||||
avg_linear_loss /= (num_iter + 1)
|
||||
avg_mel_loss /= (num_iter + 1)
|
||||
avg_total_loss = avg_mel_loss + avg_linear_loss
|
||||
|
||||
# Plot Learning Stats
|
||||
|
@ -339,15 +342,7 @@ def main(args):
|
|||
else:
|
||||
criterion = nn.L1Loss()
|
||||
|
||||
if args.restore_step:
|
||||
checkpoint = torch.load(os.path.join(
|
||||
args.restore_path, 'checkpoint_%d.pth.tar' % args.restore_step))
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
print("\n > Model restored from step %d\n" % args.restore_step)
|
||||
start_epoch = checkpoint['step'] // len(train_loader)
|
||||
best_loss = checkpoint['linear_loss']
|
||||
elif args.restore_path:
|
||||
if args.restore_path:
|
||||
checkpoint = torch.load(args.restore_path)
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
|
|
Loading…
Reference in New Issue