mirror of https://github.com/coqui-ai/TTS.git
add location attn to decoder
parent 243204bc3e
commit 288a6b5b1d
@@ -2,8 +2,6 @@
 import torch
 from torch import nn
 from .attention import AttentionRNN
-from .attention import get_mask_from_lengths
 
 
 class Prenet(nn.Module):
     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
@@ -270,8 +268,8 @@ class Decoder(nn.Module):
                                for _ in range(len(self.decoder_rnns))]
         current_context_vec = inputs.data.new(B, 256).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
-        attention_vec = memory.data.new(B, T).zero_()
-        attention_vec_cum = memory.data.new(B, T).zero_()
+        attention_vec = inputs.data.new(B, T).zero_()
+        attention_vec_cum = inputs.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
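The two state allocations changed in the hunk above now come from `inputs` rather than `memory`: the `if memory is not None:` guard shows that `memory` (the teacher-forced frames) can be absent, typically at inference time, in which case `memory.data.new(...)` would fail, while the encoder outputs in `inputs` are always available. A minimal, self-contained sketch of that failure mode with assumed shapes and stand-in tensors (not the repository's code):

import torch

# Stand-ins with assumed shapes: `inputs` plays the role of the encoder
# outputs (always present), `memory` the teacher-forced frames (None at
# inference, as the `if memory is not None:` guard above suggests).
B, T = 2, 37
inputs = torch.zeros(B, T, 256)
memory = None

attention_vec = inputs.data.new(B, T).zero_()      # current attention weights
attention_vec_cum = inputs.data.new(B, T).zero_()  # cumulative attention weights
# The old allocation, memory.data.new(B, T).zero_(), would raise
# AttributeError here because memory is None.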
@@ -290,12 +288,11 @@ class Decoder(nn.Module):
             processed_memory = self.prenet(memory_input)
             # Attention RNN
             attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
-                                           attention_vec_cum.unsqueeze(1)),
+                                           attention_vec_cum.unsqueeze(1) / (t + 1)),
                                          dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_vec_cat)
             attention_vec_cum += attention_vec
-            attention_vec_cum /= (t + 1)
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
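In the hunk above, `attention_vec_cat` stacks the previous step's attention weights and their cumulative sum into a (B, 2, T) tensor of location features for `self.attention_rnn`; the sum is now normalized by the step count `(t + 1)` at the point of use instead of renormalizing `attention_vec_cum` in place every step. The `AttentionRNN` itself is not part of this diff; in location-sensitive attention (Chorowski et al., 2015) such features are usually convolved and added into the alignment energies. A hedged sketch of that idea with made-up module and parameter names (`LocationSensitiveScorer`, filter and kernel sizes), not the repository's implementation:

import torch
from torch import nn
import torch.nn.functional as F


class LocationSensitiveScorer(nn.Module):
    """Hedged sketch (not the repository's AttentionRNN): score encoder steps
    from a query vector plus convolved location features (previous attention
    weights and their running average)."""

    def __init__(self, query_dim=256, annot_dim=256, attn_dim=128,
                 loc_filters=32, loc_kernel=31):
        super().__init__()
        self.query_layer = nn.Linear(query_dim, attn_dim, bias=False)
        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=False)
        # Two input channels: current attention weights and their running mean,
        # i.e. the two rows concatenated into attention_vec_cat above.
        self.loc_conv = nn.Conv1d(2, loc_filters, kernel_size=loc_kernel,
                                  padding=(loc_kernel - 1) // 2, bias=False)
        self.loc_layer = nn.Linear(loc_filters, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1)

    def forward(self, query, annots, attention_cat):
        # query: (B, query_dim), annots: (B, T, annot_dim), attention_cat: (B, 2, T)
        loc_feats = self.loc_conv(attention_cat).transpose(1, 2)  # (B, T, loc_filters)
        energies = self.v(torch.tanh(self.query_layer(query).unsqueeze(1)
                                     + self.annot_layer(annots)
                                     + self.loc_layer(loc_feats))).squeeze(-1)
        return F.softmax(energies, dim=1)  # (B, T) alignment over encoder steps


# Usage mirroring the decoder loop in the diff (shapes assumed):
B, T = 2, 37
annots = torch.zeros(B, T, 256)          # stand-in for encoder outputs
query = torch.zeros(B, 256)              # stand-in for the attention RNN state
attention_vec = torch.zeros(B, T)
attention_vec_cum = torch.zeros(B, T)
scorer = LocationSensitiveScorer()
for t in range(3):
    attention_vec_cat = torch.cat((attention_vec.unsqueeze(1),
                                   attention_vec_cum.unsqueeze(1) / (t + 1)), dim=1)
    attention_vec = scorer(query, annots, attention_vec_cat)
    attention_vec_cum = attention_vec_cum + attention_vec

Feeding the convolved location features into the energies lets the scorer see where attention has already been placed, which is the usual motivation for location-sensitive attention in Tacotron-style decoders.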