small clean

pull/10/head
Eren Golge 2018-03-06 05:39:54 -08:00
parent a2a2065bb4
commit 405fbc434e
1 changed files with 18 additions and 23 deletions

View File

@ -31,8 +31,6 @@ from models.tacotron import Tacotron
use_cuda = torch.cuda.is_available()
parser = argparse.ArgumentParser()
parser.add_argument('--restore_step', type=int,
help='Global step to restore checkpoint', default=0)
parser.add_argument('--restore_path', type=str,
help='Folder path to checkpoints', default=0)
parser.add_argument('--config_path', type=str,
@ -67,6 +65,8 @@ def signal_handler(signal, frame):
def train(model, criterion, data_loader, optimizer, epoch):
model = model.train()
epoch_time = 0
avg_linear_loss = 0
avg_mel_loss = 0
print(" | > Epoch {}/{}".format(epoch, c.epochs))
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
@ -180,11 +180,10 @@ def train(model, criterion, data_loader, optimizer, epoch):
print(audio_signal.max())
print(audio_signal.min())
avg_linear_loss = np.mean(
progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1]))
avg_mel_loss = np.mean(
progbar.sum_values['mel_loss'][0] / max(1, progbar.sum_values['mel_loss'][1]))
avg_total_loss = avg_mel_loss + avg_linear_loss
avg_linear_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss
# Plot Training Epoch Stats
tb.add_scalar('TrainEpochLoss/TotalLoss', loss.data[0], current_step)
@ -203,6 +202,9 @@ def evaluate(model, criterion, data_loader, current_step):
print("\n | > Validation")
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
avg_linear_loss = 0
avg_mel_loss = 0
for num_iter, data in enumerate(data_loader):
start_time = time.time()
@ -242,19 +244,22 @@ def evaluate(model, criterion, data_loader, current_step):
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
('linear_loss', linear_loss.data[0]),
('mel_loss', mel_loss.data[0])])
avg_linear_loss += linear_loss.data[0]
avg_mel_loss += avg_mel_loss.data[0]
# Diagnostic visualizations
idx = np.random.randint(c.batch_size)
const_spec = linear_output[idx].data.cpu().numpy()
gt_spec = linear_spec_var[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
align_img = plot_alignment(align_img)
tb.add_image('ValVisual/Reconstruction', const_spec, current_step)
tb.add_image('ValVisual/GroundTruth', gt_spec, current_step)
align_img = alignments[idx].data.cpu().numpy()
align_img = plot_alignment(align_img)
tb.add_image('ValVisual/ValidationAlignment', align_img, current_step)
# Sample audio
@ -270,10 +275,8 @@ def evaluate(model, criterion, data_loader, current_step):
print(audio_signal.min())
# compute average losses
avg_linear_loss = np.mean(
progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1]))
avg_mel_loss = np.mean(
progbar.sum_values['mel_loss'][0] / max(1, progbar.sum_values['mel_loss'][1]))
avg_linear_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss
# Plot Learning Stats
@ -339,15 +342,7 @@ def main(args):
else:
criterion = nn.L1Loss()
if args.restore_step:
checkpoint = torch.load(os.path.join(
args.restore_path, 'checkpoint_%d.pth.tar' % args.restore_step))
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("\n > Model restored from step %d\n" % args.restore_step)
start_epoch = checkpoint['step'] // len(train_loader)
best_loss = checkpoint['linear_loss']
elif args.restore_path:
if args.restore_path:
checkpoint = torch.load(args.restore_path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])