From e12bbc2a5ccfbccb91ef678d44b7e9f29f2641f2 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 22 Jan 2019 18:25:55 +0100
Subject: [PATCH] init decoder states with a function

---
 layers/tacotron.py | 59 ++++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 33 insertions(+), 26 deletions(-)

diff --git a/layers/tacotron.py b/layers/tacotron.py
index 4c171938..63f5d157 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -28,8 +28,7 @@ class Prenet(nn.Module):
     def init_layers(self):
         for layer in self.layers:
             torch.nn.init.xavier_uniform_(
-                layer.weight,
-                gain=torch.nn.init.calculate_gain('relu'))
+                layer.weight, gain=torch.nn.init.calculate_gain('relu'))
 
     def forward(self, inputs):
         for linear in self.layers:
@@ -88,8 +87,7 @@ class BatchNormConv1d(nn.Module):
         else:
             raise RuntimeError('Unknown activation function')
         torch.nn.init.xavier_uniform_(
-            self.conv1d.weight,
-            gain=torch.nn.init.calculate_gain(w_gain))
+            self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain))
 
     def forward(self, x):
         x = self.padder(x)
@@ -113,11 +111,9 @@ class Highway(nn.Module):
 
     def init_layers(self):
         torch.nn.init.xavier_uniform_(
-            self.H.weight,
-            gain=torch.nn.init.calculate_gain('relu'))
+            self.H.weight, gain=torch.nn.init.calculate_gain('relu'))
         torch.nn.init.xavier_uniform_(
-            self.T.weight,
-            gain=torch.nn.init.calculate_gain('sigmoid'))
+            self.T.weight, gain=torch.nn.init.calculate_gain('sigmoid'))
 
     def forward(self, inputs):
         H = self.relu(self.H(inputs))
@@ -339,6 +335,9 @@ class Decoder(nn.Module):
             gain=torch.nn.init.calculate_gain('linear'))
 
     def _reshape_memory(self, memory):
+        """
+        Reshape the spectrograms for given 'r'
+        """
         B = memory.shape[0]
         # Grouping multiple frames if necessary
         if memory.size(-1) == self.memory_dim:
@@ -348,6 +347,27 @@ class Decoder(nn.Module):
         memory = memory.transpose(0, 1)
         return memory
 
+    def _init_states(self, inputs):
+        """
+        Initialization of decoder states
+        """
+        B = inputs.size(0)
+        T = inputs.size(1)
+        # go frame as zeros matrix
+        initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
+        # decoder states
+        attention_rnn_hidden = inputs.data.new(B, 256).zero_()
+        decoder_rnn_hiddens = [
+            inputs.data.new(B, 256).zero_()
+            for _ in range(len(self.decoder_rnns))
+        ]
+        current_context_vec = inputs.data.new(B, self.in_features).zero_()
+        # attention states
+        attention = inputs.data.new(B, T).zero_()
+        attention_cum = inputs.data.new(B, T).zero_()
+        return (initial_memory, attention_rnn_hidden, decoder_rnn_hiddens,
+                current_context_vec, attention, attention_cum)
+
     def forward(self, inputs, memory=None, mask=None):
         """
         Decoder forward step.
@@ -366,30 +386,17 @@ class Decoder(nn.Module):
             - inputs: batch x time x encoder_out_dim
             - memory: batch x #mel_specs x mel_spec_dim
         """
-        B = inputs.size(0)
-        T = inputs.size(1)
         # Run greedy decoding if memory is None
         greedy = not self.training
         if memory is not None:
             memory = self._reshape_memory(memory)
             T_decoder = memory.size(0)
-        # go frame as zeros matrix
-        initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
-        # decoder states
-        attention_rnn_hidden = inputs.data.new(B, 256).zero_()
-        decoder_rnn_hiddens = [
-            inputs.data.new(B, 256).zero_()
-            for _ in range(len(self.decoder_rnns))
-        ]
-        current_context_vec = inputs.data.new(B, self.in_features).zero_()
-        # attention states
-        attention = inputs.data.new(B, T).zero_()
-        attention_cum = inputs.data.new(B, T).zero_()
         outputs = []
         attentions = []
         stop_tokens = []
         t = 0
-        memory_input = initial_memory
+        memory_input, attention_rnn_hidden, decoder_rnn_hiddens,\
+            current_context_vec, attention, attention_cum = self._init_states(inputs)
         while True:
             if t > 0:
                 if memory is None:
@@ -434,7 +441,8 @@ class Decoder(nn.Module):
                 if t >= T_decoder:
                     break
             else:
-                if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
+                if t > inputs.shape[1] / 4 and (stop_token > 0.6 or
+                                                attention[:, -1].item() > 0.6):
                     break
                 elif t > self.max_decoder_steps:
                     print("   | > Decoder stopped with 'max_decoder_steps")
@@ -460,8 +468,7 @@ class StopNet(nn.Module):
         self.linear = nn.Linear(in_features, 1)
         self.sigmoid = nn.Sigmoid()
         torch.nn.init.xavier_uniform_(
-            self.linear.weight,
-            gain=torch.nn.init.calculate_gain('linear'))
+            self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
 
     def forward(self, inputs):
         outputs = self.dropout(inputs)
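
Note for reviewers: below is a minimal, self-contained sketch of the pattern this patch applies, in case the diff is easier to read against a toy example. MiniDecoder, its sizes, and its state names are placeholders (hypothetical, not the TTS Decoder class): the point is only that the zero-state allocations move out of forward() into an _init_states() helper whose return tuple is unpacked once at the top of the decoding loop.

# Hypothetical mini-decoder illustrating the refactor; names and sizes
# are placeholders, not the real model's.
import torch
import torch.nn as nn


class MiniDecoder(nn.Module):
    def __init__(self, in_features=8, memory_dim=4, r=2, hidden=16, n_rnns=2):
        super().__init__()
        self.in_features = in_features
        self.memory_dim = memory_dim
        self.r = r
        self.hidden = hidden
        self.n_rnns = n_rnns

    def _init_states(self, inputs):
        # inputs.data.new(...).zero_() allocates zeros with the same dtype
        # and device as `inputs`; torch.zeros(..., dtype=..., device=...)
        # is the modern spelling of the same thing.
        B, T = inputs.size(0), inputs.size(1)
        initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
        attention_rnn_hidden = inputs.data.new(B, self.hidden).zero_()
        decoder_rnn_hiddens = [
            inputs.data.new(B, self.hidden).zero_() for _ in range(self.n_rnns)
        ]
        current_context_vec = inputs.data.new(B, self.in_features).zero_()
        attention = inputs.data.new(B, T).zero_()
        attention_cum = inputs.data.new(B, T).zero_()
        return (initial_memory, attention_rnn_hidden, decoder_rnn_hiddens,
                current_context_vec, attention, attention_cum)

    def forward(self, inputs):
        # forward() now starts from a single unpacking line instead of a
        # block of inline allocations, mirroring the patched Decoder.forward.
        (memory_input, attention_rnn_hidden, decoder_rnn_hiddens,
         current_context_vec, attention, attention_cum) = self._init_states(inputs)
        return memory_input.shape, attention.shape


if __name__ == "__main__":
    dec = MiniDecoder()
    print(dec(torch.randn(3, 10, 8)))
    # (torch.Size([3, 8]), torch.Size([3, 10]))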