init decoder states with a function

pull/10/head
Eren Golge 2019-01-22 18:25:55 +01:00
parent 66f8d0e260
commit e12bbc2a5c
1 changed files with 33 additions and 26 deletions

View File

@ -28,8 +28,7 @@ class Prenet(nn.Module):
def init_layers(self):
for layer in self.layers:
torch.nn.init.xavier_uniform_(
layer.weight,
gain=torch.nn.init.calculate_gain('relu'))
layer.weight, gain=torch.nn.init.calculate_gain('relu'))
def forward(self, inputs):
for linear in self.layers:
@ -88,8 +87,7 @@ class BatchNormConv1d(nn.Module):
else:
raise RuntimeError('Unknown activation function')
torch.nn.init.xavier_uniform_(
self.conv1d.weight,
gain=torch.nn.init.calculate_gain(w_gain))
self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain))
def forward(self, x):
x = self.padder(x)
@ -113,11 +111,9 @@ class Highway(nn.Module):
def init_layers(self):
torch.nn.init.xavier_uniform_(
self.H.weight,
gain=torch.nn.init.calculate_gain('relu'))
self.H.weight, gain=torch.nn.init.calculate_gain('relu'))
torch.nn.init.xavier_uniform_(
self.T.weight,
gain=torch.nn.init.calculate_gain('sigmoid'))
self.T.weight, gain=torch.nn.init.calculate_gain('sigmoid'))
def forward(self, inputs):
H = self.relu(self.H(inputs))
@ -339,6 +335,9 @@ class Decoder(nn.Module):
gain=torch.nn.init.calculate_gain('linear'))
def _reshape_memory(self, memory):
"""
Reshape the spectrograms for given 'r'
"""
B = memory.shape[0]
# Grouping multiple frames if necessary
if memory.size(-1) == self.memory_dim:
@ -348,6 +347,27 @@ class Decoder(nn.Module):
memory = memory.transpose(0, 1)
return memory
def _init_states(self, inputs):
"""
Initialization of decoder states
"""
B = inputs.size(0)
T = inputs.size(1)
# go frame as zeros matrix
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
# decoder states
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
decoder_rnn_hiddens = [
inputs.data.new(B, 256).zero_()
for _ in range(len(self.decoder_rnns))
]
current_context_vec = inputs.data.new(B, self.in_features).zero_()
# attention states
attention = inputs.data.new(B, T).zero_()
attention_cum = inputs.data.new(B, T).zero_()
return (initial_memory, attention_rnn_hidden, decoder_rnn_hiddens,
current_context_vec, attention, attention_cum)
def forward(self, inputs, memory=None, mask=None):
"""
Decoder forward step.
@ -366,30 +386,17 @@ class Decoder(nn.Module):
- inputs: batch x time x encoder_out_dim
- memory: batch x #mel_specs x mel_spec_dim
"""
B = inputs.size(0)
T = inputs.size(1)
# Run greedy decoding if memory is None
greedy = not self.training
if memory is not None:
memory = self._reshape_memory(memory)
T_decoder = memory.size(0)
# go frame as zeros matrix
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
# decoder states
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
decoder_rnn_hiddens = [
inputs.data.new(B, 256).zero_()
for _ in range(len(self.decoder_rnns))
]
current_context_vec = inputs.data.new(B, self.in_features).zero_()
# attention states
attention = inputs.data.new(B, T).zero_()
attention_cum = inputs.data.new(B, T).zero_()
outputs = []
attentions = []
stop_tokens = []
t = 0
memory_input = initial_memory
memory_input, attention_rnn_hidden, decoder_rnn_hiddens,\
current_context_vec, attention, attention_cum = self._init_states(inputs)
while True:
if t > 0:
if memory is None:
@ -434,7 +441,8 @@ class Decoder(nn.Module):
if t >= T_decoder:
break
else:
if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
if t > inputs.shape[1] / 4 and (stop_token > 0.6 or
attention[:, -1].item() > 0.6):
break
elif t > self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
@ -460,8 +468,7 @@ class StopNet(nn.Module):
self.linear = nn.Linear(in_features, 1)
self.sigmoid = nn.Sigmoid()
torch.nn.init.xavier_uniform_(
self.linear.weight,
gain=torch.nn.init.calculate_gain('linear'))
self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
def forward(self, inputs):
outputs = self.dropout(inputs)