mirror of https://github.com/coqui-ai/TTS.git
plot attention alignments
parent
235ce071c6
commit
4e0ab65bbf
Binary file not shown.
|
@ -53,7 +53,7 @@ class LJSpeechDataset(Dataset):
|
|||
|
||||
text = [d['text'] for d in batch]
|
||||
text_lenghts = [len(x) for x in text]
|
||||
max_text_len = np.max(text_lengths)
|
||||
max_text_len = np.max(text_lenghts)
|
||||
wav = [d['wav'] for d in batch]
|
||||
|
||||
# PAD sequences with largest length of the batch
|
||||
|
|
|
@ -55,6 +55,7 @@ class AttentionWrapper(nn.Module):
|
|||
processed_memory=None, mask=None, memory_lengths=None):
|
||||
if processed_memory is None:
|
||||
processed_memory = memory
|
||||
|
||||
if memory_lengths is not None and mask is None:
|
||||
mask = get_mask_from_lengths(memory, memory_lengths)
|
||||
|
||||
|
@ -73,7 +74,7 @@ class AttentionWrapper(nn.Module):
|
|||
alignment.data.masked_fill_(mask, self.score_mask_value)
|
||||
|
||||
# Normalize attention weight
|
||||
alignment = F.softmax(alignment, dim=-1) ## TODO: might be buggy
|
||||
alignment = F.softmax(alignment, dim=-1)
|
||||
|
||||
# Attention context vector
|
||||
# (batch, 1, dim)
|
||||
|
|
File diff suppressed because one or more lines are too long
7
train.py
7
train.py
|
@ -123,7 +123,7 @@ def main(args):
|
|||
# setup lr
|
||||
current_lr = lr_decay(c.lr, current_step)
|
||||
for params_group in optimizer.param_groups:
|
||||
param_group['lr'] = current_lr
|
||||
params_group['lr'] = current_lr
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
@ -204,11 +204,14 @@ def main(args):
|
|||
checkpoint_path)
|
||||
print("\n | > Checkpoint is saved : {}".format(checkpoint_path))
|
||||
|
||||
# Log spectrogram reconstruction
|
||||
# Diagnostic visualizations
|
||||
const_spec = linear_output[0].data.cpu()[None, :]
|
||||
gt_spec = linear_spec_var[0].data.cpu()[None, :]
|
||||
align_img = alignments[0].data.cpu().t()[None, :]
|
||||
tb.add_image('Spec/Reconstruction', const_spec, current_step)
|
||||
tb.add_image('Spec/GroundTruth', gt_spec, current_step)
|
||||
tb.add_image('Attn/Alignment', align_img, current_step)
|
||||
|
||||
|
||||
#lr_scheduler.step(loss.data[0])
|
||||
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
||||
|
|
Loading…
Reference in New Issue