mirror of https://github.com/coqui-ai/TTS.git
init decoder states with a function
parent 66f8d0e260
commit e12bbc2a5c
@@ -28,8 +28,7 @@ class Prenet(nn.Module):
     def init_layers(self):
         for layer in self.layers:
             torch.nn.init.xavier_uniform_(
-                layer.weight,
-                gain=torch.nn.init.calculate_gain('relu'))
+                layer.weight, gain=torch.nn.init.calculate_gain('relu'))
 
     def forward(self, inputs):
         for linear in self.layers:
@@ -88,8 +87,7 @@ class BatchNormConv1d(nn.Module):
         else:
             raise RuntimeError('Unknown activation function')
         torch.nn.init.xavier_uniform_(
-            self.conv1d.weight,
-            gain=torch.nn.init.calculate_gain(w_gain))
+            self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain))
 
     def forward(self, x):
         x = self.padder(x)
@@ -113,11 +111,9 @@ class Highway(nn.Module):
 
     def init_layers(self):
         torch.nn.init.xavier_uniform_(
-            self.H.weight,
-            gain=torch.nn.init.calculate_gain('relu'))
+            self.H.weight, gain=torch.nn.init.calculate_gain('relu'))
         torch.nn.init.xavier_uniform_(
-            self.T.weight,
-            gain=torch.nn.init.calculate_gain('sigmoid'))
+            self.T.weight, gain=torch.nn.init.calculate_gain('sigmoid'))
 
     def forward(self, inputs):
         H = self.relu(self.H(inputs))
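The three hunks above only re-wrap the same Xavier initialisation call onto a single line; behaviour is unchanged. For reference, a minimal standalone sketch of that pattern (layer sizes here are invented for illustration, not taken from the file):

import torch
from torch import nn

# Xavier/Glorot uniform init, scaled by the gain of the following non-linearity.
linear = nn.Linear(128, 256)   # e.g. a Prenet layer followed by ReLU
torch.nn.init.xavier_uniform_(
    linear.weight, gain=torch.nn.init.calculate_gain('relu'))

gate = nn.Linear(128, 128)     # e.g. the Highway T gate followed by a sigmoid
torch.nn.init.xavier_uniform_(
    gate.weight, gain=torch.nn.init.calculate_gain('sigmoid'))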
@@ -339,6 +335,9 @@ class Decoder(nn.Module):
             gain=torch.nn.init.calculate_gain('linear'))
 
     def _reshape_memory(self, memory):
+        """
+        Reshape the spectrograms for given 'r'
+        """
         B = memory.shape[0]
         # Grouping multiple frames if necessary
         if memory.size(-1) == self.memory_dim:
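This hunk appears to add only the docstring; `_reshape_memory` itself is unchanged. A rough standalone sketch of the reshaping it performs, with illustrative shapes and reduction factor `r` (names and defaults below are assumptions, not copied from the file):

import torch

def reshape_memory_sketch(memory, memory_dim=80, r=5):
    # memory: (B, T_mel, memory_dim) target spectrogram frames
    B = memory.shape[0]
    if memory.size(-1) == memory_dim:
        # group r consecutive frames into a single decoder step
        memory = memory.contiguous().view(B, memory.size(1) // r, memory_dim * r)
    # time-major layout for the decoding loop: (T_decoder, B, memory_dim * r)
    return memory.transpose(0, 1)

# e.g. reshape_memory_sketch(torch.randn(2, 100, 80)).shape == (20, 2, 400)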
@@ -348,6 +347,27 @@ class Decoder(nn.Module):
         memory = memory.transpose(0, 1)
         return memory
 
+    def _init_states(self, inputs):
+        """
+        Initialization of decoder states
+        """
+        B = inputs.size(0)
+        T = inputs.size(1)
+        # go frame as zeros matrix
+        initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
+        # decoder states
+        attention_rnn_hidden = inputs.data.new(B, 256).zero_()
+        decoder_rnn_hiddens = [
+            inputs.data.new(B, 256).zero_()
+            for _ in range(len(self.decoder_rnns))
+        ]
+        current_context_vec = inputs.data.new(B, self.in_features).zero_()
+        # attention states
+        attention = inputs.data.new(B, T).zero_()
+        attention_cum = inputs.data.new(B, T).zero_()
+        return (initial_memory, attention_rnn_hidden, decoder_rnn_hiddens,
+                current_context_vec, attention, attention_cum)
+
     def forward(self, inputs, memory=None, mask=None):
         """
         Decoder forward step.
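The new `_init_states` helper allocates all-zero decoder state tensors on the same device and dtype as the encoder outputs, via the legacy `inputs.data.new(...).zero_()` idiom. A rough present-day equivalent, for orientation only (the 256 hidden size and state names mirror the hunk; the other dimensions are made up):

import torch

def init_states_sketch(inputs, memory_dim=80, r=5, rnn_dim=256, n_decoder_rnns=2):
    # inputs: (B, T, encoder_dim) encoder outputs
    B, T, encoder_dim = inputs.size()

    def zeros(*shape):
        return torch.zeros(*shape, device=inputs.device, dtype=inputs.dtype)

    go_frame = zeros(B, memory_dim * r)        # all-zero <GO> frame
    attention_rnn_hidden = zeros(B, rnn_dim)
    decoder_rnn_hiddens = [zeros(B, rnn_dim) for _ in range(n_decoder_rnns)]
    context_vec = zeros(B, encoder_dim)
    attention = zeros(B, T)
    attention_cum = zeros(B, T)                # cumulative attention weights
    return (go_frame, attention_rnn_hidden, decoder_rnn_hiddens,
            context_vec, attention, attention_cum)

# usage: states = init_states_sketch(torch.randn(2, 37, 256))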
@@ -366,30 +386,17 @@ class Decoder(nn.Module):
             - inputs: batch x time x encoder_out_dim
             - memory: batch x #mel_specs x mel_spec_dim
         """
-        B = inputs.size(0)
-        T = inputs.size(1)
         # Run greedy decoding if memory is None
         greedy = not self.training
         if memory is not None:
             memory = self._reshape_memory(memory)
             T_decoder = memory.size(0)
-        # go frame as zeros matrix
-        initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
-        # decoder states
-        attention_rnn_hidden = inputs.data.new(B, 256).zero_()
-        decoder_rnn_hiddens = [
-            inputs.data.new(B, 256).zero_()
-            for _ in range(len(self.decoder_rnns))
-        ]
-        current_context_vec = inputs.data.new(B, self.in_features).zero_()
-        # attention states
-        attention = inputs.data.new(B, T).zero_()
-        attention_cum = inputs.data.new(B, T).zero_()
         outputs = []
         attentions = []
         stop_tokens = []
         t = 0
-        memory_input = initial_memory
+        memory_input, attention_rnn_hidden, decoder_rnn_hiddens,\
+            current_context_vec, attention, attention_cum = self._init_states(inputs)
         while True:
             if t > 0:
                 if memory is None:
@@ -434,7 +441,8 @@ class Decoder(nn.Module):
                 if t >= T_decoder:
                     break
             else:
-                if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
+                if t > inputs.shape[1] / 4 and (stop_token > 0.6 or
+                                                attention[:, -1].item() > 0.6):
                     break
                 elif t > self.max_decoder_steps:
                     print(" | > Decoder stopped with 'max_decoder_steps")
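The reformatted condition encodes the greedy-inference stop rule: once at least a quarter of the input length has been decoded, stop as soon as either the stop-token output or the attention weight on the last encoder step exceeds 0.6. A hypothetical helper (names invented) restating that rule:

def should_stop(t, input_len, stop_token_prob, last_attention_weight, threshold=0.6):
    # never stop before a quarter of the input length has been decoded
    warmed_up = t > input_len / 4
    return warmed_up and (stop_token_prob > threshold or
                          last_attention_weight > threshold)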
@@ -460,8 +468,7 @@ class StopNet(nn.Module):
         self.linear = nn.Linear(in_features, 1)
         self.sigmoid = nn.Sigmoid()
         torch.nn.init.xavier_uniform_(
-            self.linear.weight,
-            gain=torch.nn.init.calculate_gain('linear'))
+            self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
 
     def forward(self, inputs):
         outputs = self.dropout(inputs)