Update train.py for stop token prediciton

pull/10/head
Eren Golge 2018-05-11 04:15:53 -07:00
parent 8be07ee3c5
commit 31cc9c5443
1 changed files with 35 additions and 20 deletions

View File

@ -59,12 +59,12 @@ LOG_DIR = OUT_PATH
tb = SummaryWriter(LOG_DIR)
def train(model, criterion, data_loader, optimizer, epoch):
def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
model = model.train()
epoch_time = 0
avg_linear_loss = 0
avg_mel_loss = 0
avg_stop_loss = 0
print(" | > Epoch {}/{}".format(epoch, c.epochs))
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
@ -96,16 +96,17 @@ def train(model, criterion, data_loader, optimizer, epoch):
linear_input = linear_input.cuda()
# forward pass
mel_output, linear_output, alignments =\
mel_output, linear_output, alignments, stop_tokens =\
model.forward(text_input, mel_input)
# loss computation
stop_loss = criterion_st(stop_tokens, stop_target)
mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_input[:, :, :n_priority_freq],
mel_lengths)
loss = mel_loss + linear_loss
loss = mel_loss + linear_loss + stop_loss
# backpass and check the grad norm
loss.backward()
@ -123,9 +124,11 @@ def train(model, criterion, data_loader, optimizer, epoch):
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
('linear_loss', linear_loss.item()),
('mel_loss', mel_loss.item()),
('stop_loss', stop_loss.item()),
('grad_norm', grad_norm.item())])
avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item()
avg_stop_loss += stop_loss.item()
# Plot Training Iter Stats
tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step)
@ -172,24 +175,26 @@ def train(model, criterion, data_loader, optimizer, epoch):
avg_linear_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss
avg_stop_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
# Plot Training Epoch Stats
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
epoch_time = 0
return avg_linear_loss, current_step
def evaluate(model, criterion, data_loader, current_step):
def evaluate(model, criterion, criterion_st, data_loader, current_step):
model = model.eval()
epoch_time = 0
avg_linear_loss = 0
avg_mel_loss = 0
avg_stop_loss = 0
print(" | > Validation")
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
@ -203,6 +208,11 @@ def evaluate(model, criterion, data_loader, current_step):
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
# set stop targets view, we predict a single stop token per r frames prediction
stop_target = stop_target.view(text_input.shape[0], stop_target.size(1) // c.r, -1)
stop_target = (stop_target.sum(2) > 0.0).float()
# dispatch data to GPU
if use_cuda:
@ -210,18 +220,20 @@ def evaluate(model, criterion, data_loader, current_step):
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda()
stop_target = stop_target.cuda()
# forward pass
mel_output, linear_output, alignments =\
model.forward(text_input, mel_input)
mel_output, linear_output, alignments, stop_tokens =\
model.forward(text_input, mel_spec)
# loss computation
stop_loss = criterion_st(stop_tokens, stop_target)
mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_input[:, :, :n_priority_freq],
mel_lengths)
loss = mel_loss + linear_loss
loss = mel_loss + linear_loss + stop_loss
step_time = time.time() - start_time
epoch_time += step_time
@ -229,10 +241,12 @@ def evaluate(model, criterion, data_loader, current_step):
# update
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
('linear_loss', linear_loss.item()),
('mel_loss', mel_loss.item())])
('mel_loss', mel_loss.item()),
('stop_loss', stop_loss.item())])
avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item()
avg_stop_loss += stop_loss.item()
# Diagnostic visualizations
idx = np.random.randint(mel_input.shape[0])
@ -263,13 +277,15 @@ def evaluate(model, criterion, data_loader, current_step):
# compute average losses
avg_linear_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss
avg_stop_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss + stop_loss
# Plot Learning Stats
tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step)
tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step)
return avg_linear_loss
@ -324,10 +340,8 @@ def main(args):
optimizer = optim.Adam(model.parameters(), lr=c.lr)
if use_cuda:
criterion = L1LossMasked().cuda()
else:
criterion = L1LossMasked()
criterion = L1LossMasked()
criterion_st = nn.BCELoss()
if args.restore_path:
checkpoint = torch.load(args.restore_path)
@ -344,6 +358,8 @@ def main(args):
if use_cuda:
model = nn.DataParallel(model.cuda())
criterion.cuda()
criterion_st.cuda()
num_params = count_parameters(model)
print(" | > Model has {} parameters".format(num_params))
@ -356,15 +372,14 @@ def main(args):
for epoch in range(0, c.epochs):
train_loss, current_step = train(
model, criterion, train_loader, optimizer, epoch)
val_loss = evaluate(model, criterion, val_loader, current_step)
model, criterion, criterion_st, train_loader, optimizer, epoch)
val_loss = evaluate(model, criterion, criterion_st, val_loader, current_step)
best_loss = save_best_model(model, optimizer, val_loss,
best_loss, OUT_PATH,
current_step, epoch)
if __name__ == '__main__':
# signal.signal(signal.SIGINT, signal_handler)
try:
main(args)
except KeyboardInterrupt: