plot attention alignments

pull/10/head
Eren Golge 2018-02-02 05:37:09 -08:00
parent 235ce071c6
commit 4e0ab65bbf
5 changed files with 66 additions and 55 deletions

Binary file not shown.


@@ -53,7 +53,7 @@ class LJSpeechDataset(Dataset):
text = [d['text'] for d in batch]
text_lenghts = [len(x) for x in text]
- max_text_len = np.max(text_lengths)
+ max_text_len = np.max(text_lenghts)
wav = [d['wav'] for d in batch]
# PAD sequences with largest length of the batch
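
For context, this collate step computes per-sample lengths and pads every text sequence to the longest one in the batch. A minimal standalone sketch of that padding, with hypothetical helper names rather than the repo's exact code:

import numpy as np

def pad_text_batch(texts, pad_value=0):
    # Pad integer-encoded text sequences to the longest sequence in the batch.
    text_lenghts = [len(x) for x in texts]
    max_text_len = np.max(text_lenghts)
    padded = np.full((len(texts), max_text_len), pad_value, dtype=np.int64)
    for i, seq in enumerate(texts):
        padded[i, :len(seq)] = seq
    return padded, text_lenghts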


@@ -55,6 +55,7 @@ class AttentionWrapper(nn.Module):
processed_memory=None, mask=None, memory_lengths=None):
if processed_memory is None:
processed_memory = memory
if memory_lengths is not None and mask is None:
mask = get_mask_from_lengths(memory, memory_lengths)
@@ -73,7 +74,7 @@ class AttentionWrapper(nn.Module):
alignment.data.masked_fill_(mask, self.score_mask_value)
# Normalize attention weight
- alignment = F.softmax(alignment, dim=-1) ## TODO: might be buggy
+ alignment = F.softmax(alignment, dim=-1)
# Attention context vector
# (batch, 1, dim)
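
The masking above fills attention scores at padded encoder positions with score_mask_value before the softmax, so padding ends up with (near-)zero attention weight. A rough standalone sketch of that pattern, assuming memory has shape (batch, steps, dim); the helper names here are illustrative, not the module's own:

import torch
import torch.nn.functional as F

def get_pad_mask(memory, memory_lengths):
    # True at positions that are padding (hypothetical stand-in for
    # get_mask_from_lengths).
    max_len = memory.size(1)
    steps = torch.arange(max_len, device=memory.device)[None, :]
    lengths = torch.tensor(memory_lengths, device=memory.device)[:, None]
    return steps >= lengths

def normalize_alignment(alignment, mask, score_mask_value=-float('inf')):
    # Suppress padded positions, then normalize over encoder steps.
    alignment = alignment.masked_fill(mask, score_mask_value)
    return F.softmax(alignment, dim=-1)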

File diff suppressed because one or more lines are too long


@@ -123,7 +123,7 @@ def main(args):
# setup lr
current_lr = lr_decay(c.lr, current_step)
for params_group in optimizer.param_groups:
- param_group['lr'] = current_lr
+ params_group['lr'] = current_lr
optimizer.zero_grad()
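
lr_decay recomputes the learning rate from the global step, and the loop writes it into every optimizer parameter group before the backward pass. A sketch of the idea; the warmup/decay formula below is a common Noam-style schedule and is an assumption, not taken from this repo:

def lr_decay(init_lr, global_step, warmup_steps=4000):
    # Assumed schedule: linear warmup, then inverse-square-root decay.
    step = max(global_step, 1)
    return init_lr * warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5,
                                               step ** -0.5)

# applied once per training step, as in the loop above:
#   current_lr = lr_decay(c.lr, current_step)
#   for params_group in optimizer.param_groups:
#       params_group['lr'] = current_lr
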
@@ -204,11 +204,14 @@ def main(args):
checkpoint_path)
print("\n | > Checkpoint is saved : {}".format(checkpoint_path))
- # Log spectrogram reconstruction
+ # Diagnostic visualizations
const_spec = linear_output[0].data.cpu()[None, :]
gt_spec = linear_spec_var[0].data.cpu()[None, :]
+ align_img = alignments[0].data.cpu().t()[None, :]
tb.add_image('Spec/Reconstruction', const_spec, current_step)
tb.add_image('Spec/GroundTruth', gt_spec, current_step)
+ tb.add_image('Attn/Alignment', align_img, current_step)
#lr_scheduler.step(loss.data[0])
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
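
The logging block takes the first sample of the batch, prepends a channel axis so each 2-D matrix can be logged as a single-channel image, and transposes the attention matrix so encoder and decoder steps land on the two image axes. A standalone sketch of the same idea, assuming a tensorboardX-style SummaryWriter (names are illustrative):

from tensorboardX import SummaryWriter

def log_diagnostics(tb, linear_output, linear_spec_var, alignments, current_step):
    # First sample of the batch, with a leading channel axis ([1, H, W]).
    const_spec = linear_output[0].data.cpu()[None, :]
    gt_spec = linear_spec_var[0].data.cpu()[None, :]
    # Transpose so one axis is encoder steps and the other decoder steps.
    align_img = alignments[0].data.cpu().t()[None, :]
    tb.add_image('Spec/Reconstruction', const_spec, current_step)
    tb.add_image('Spec/GroundTruth', gt_spec, current_step)
    tb.add_image('Attn/Alignment', align_img, current_step)

# e.g. tb = SummaryWriter('logs'), called each time a checkpoint is written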