diff --git a/layers/attention.py b/layers/attention.py
index 5e468cfb..31cd03b6 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -39,7 +39,7 @@ class LocationSensitiveAttention(nn.Module):
         self.kernel_size = kernel_size
         self.filters = filters
         padding = int((kernel_size - 1) / 2)
-        self.loc_conv = nn.Conv1d(1, filters,
+        self.loc_conv = nn.Conv1d(2, filters,
                                   kernel_size=kernel_size, stride=1,
                                   padding=padding, bias=False)
         self.loc_linear = nn.Linear(filters, attn_dim)
@@ -77,15 +77,15 @@ class AttentionRNNCell(nn.Module):
             out_dim (int): context vector feature dimension.
             rnn_dim (int): rnn hidden state dimension.
             annot_dim (int): annotation vector feature dimension.
-            memory_dim (int): memory vector (decoder autogression) feature dimension.
+            memory_dim (int): memory vector (decoder output) feature dimension.
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
-        self.rnn_cell = nn.GRUCell(out_dim + memory_dim, rnn_dim)
+        self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
         # pick bahdanau or location sensitive attention
         if align_model == 'b':
-            self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
+            self.alignment_model = BahdanauAttention(annot_dim, rnn_dim, out_dim)
         if align_model == 'ls':
             self.alignment_model = LocationSensitiveAttention(annot_dim, rnn_dim, out_dim)
         else:
diff --git a/layers/tacotron.py b/layers/tacotron.py
index e021cd07..8c545c90 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -198,6 +198,7 @@ class Decoder(nn.Module):
     def __init__(self, in_features, memory_dim, r):
         super(Decoder, self).__init__()
         self.r = r
+        self.in_features = in_features
         self.max_decoder_steps = 200
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
@@ -249,7 +250,7 @@ class Decoder(nn.Module):
         attention_rnn_hidden = inputs.data.new(B, 256).zero_()
         decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
                                for _ in range(len(self.decoder_rnns))]
-        current_context_vec = inputs.data.new(B, 128).zero_()
+        current_context_vec = inputs.data.new(B, self.in_features).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
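
For reference, a minimal shape sketch of why these dimension fixes line up. It assumes the attention RNN consumes the concatenation of the previous context vector and the memory input, and that location features are built from the attention weights stacked with their cumulative sum (the usual location-sensitive setup); the dimension values are illustrative, not taken from the repo's config.

import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from the model config.
B, T = 4, 50            # batch size, encoder time steps
annot_dim = 256         # encoder annotation size == context vector size
memory_dim = 80         # decoder memory frame size
rnn_dim = 256           # attention RNN hidden size
filters, kernel_size = 32, 31

# The context vector is a weighted sum of annotations, so it has annot_dim
# features -- hence GRUCell(annot_dim + memory_dim, rnn_dim); the same
# reasoning sizes current_context_vec with self.in_features in the Decoder.
rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
context = torch.zeros(B, annot_dim)
memory = torch.zeros(B, memory_dim)
h = rnn_cell(torch.cat((context, memory), dim=-1))          # (B, rnn_dim)

# Stacking the attention weights with their cumulative sum gives 2 input
# channels -- hence Conv1d(2, filters, ...) instead of Conv1d(1, ...).
loc_conv = nn.Conv1d(2, filters, kernel_size=kernel_size, stride=1,
                     padding=(kernel_size - 1) // 2, bias=False)
attn, attn_cum = torch.zeros(B, T), torch.zeros(B, T)
loc_feats = loc_conv(torch.stack((attn, attn_cum), dim=1))  # (B, filters, T)

print(h.shape, loc_feats.shape)  # torch.Size([4, 256]) torch.Size([4, 32, 50])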