2018-01-22 14:59:41 +00:00
|
|
|
# coding: utf-8
|
|
|
|
import torch
|
|
|
|
from torch.autograd import Variable
|
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
from .attention import BahdanauAttention, AttentionWrapper
|
|
|
|
from .attention import get_mask_from_lengths
|
|
|
|
|
|
|
|
class Prenet(nn.Module):
    """Stack of ``Linear -> ReLU -> Dropout`` layers (Tacotron "pre-net").

    Args:
        in_dim: size of the last dimension of the input tensor.
        sizes: output size of each successive layer; any sequence of ints.
            Defaults to ``[256, 128]``.
        dropout: dropout probability applied after every layer. Defaults to
            0.5. ``nn.Dropout`` is only active in training mode.
    """

    def __init__(self, in_dim, sizes=None, dropout=0.5):
        super(Prenet, self).__init__()
        # None sentinel instead of a shared mutable default; list() also lets
        # callers pass a tuple (list + tuple concatenation would raise).
        sizes = [256, 128] if sizes is None else list(sizes)
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [nn.Linear(in_size, out_size)
             for (in_size, out_size) in zip(in_sizes, sizes)])
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        """Apply every layer in turn; the output's last dim is ``sizes[-1]``."""
        for linear in self.layers:
            inputs = self.dropout(self.relu(linear(inputs)))
        return inputs
|
|
|
|
|
|
|
|
|
|
|
|
class BatchNormConv1d(nn.Module):
    """``Conv1d`` (no bias) -> optional activation -> ``BatchNorm1d``.

    Args:
        in_dim: input channels of the convolution.
        out_dim: output channels of the convolution (and of the batch norm).
        kernel_size, stride, padding: forwarded to ``nn.Conv1d``.
        activation: optional callable applied between conv and batch norm.
        momentum: PyTorch ``BatchNorm1d`` momentum, i.e. the weight given to
            the current batch statistic. Defaults to 0.01.
        eps: ``BatchNorm1d`` epsilon. Defaults to 1e-3.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
                 activation=None, momentum=0.01, eps=1e-3):
        super(BatchNormConv1d, self).__init__()
        self.conv1d = nn.Conv1d(in_dim, out_dim,
                                kernel_size=kernel_size,
                                stride=stride, padding=padding, bias=False)
        # Following TensorFlow's default parameters (decay 0.99, eps 1e-3).
        # PyTorch updates running = (1 - momentum) * running + momentum * batch,
        # so TF's decay of 0.99 corresponds to momentum=0.01 here. The previous
        # momentum=0.99 made the running stats follow almost only the last batch.
        self.bn = nn.BatchNorm1d(out_dim, momentum=momentum, eps=eps)
        self.activation = activation

    def forward(self, x):
        """Convolve ``x``, apply the activation (if any), then batch-normalize."""
        x = self.conv1d(x)
        if self.activation is not None:
            x = self.activation(x)
        return self.bn(x)
|
|
|
|
|
|
|
|
|
|
|
|
class Highway(nn.Module):
    """Highway layer: ``relu(H(x)) * t + x * (1 - t)`` with ``t = sigmoid(T(x))``.

    ``H`` starts with a zero bias and the gate ``T`` with a bias of -1, which
    biases the gate toward carrying the input through.

    Args:
        in_size: size of the input's last dimension.
        out_size: size of the output's last dimension. The carry term
            multiplies the input elementwise with the gate, so the two sizes
            must broadcast (in practice they are equal).
    """

    def __init__(self, in_size, out_size):
        super(Highway, self).__init__()
        self.H = nn.Linear(in_size, out_size)
        self.H.bias.data.fill_(0)
        self.T = nn.Linear(in_size, out_size)
        self.T.bias.data.fill_(-1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """Blend the transformed and the untouched input by the learned gate."""
        gate = self.sigmoid(self.T(inputs))
        transformed = self.relu(self.H(inputs))
        carried = inputs * (1.0 - gate)
        return transformed * gate + carried
|
|
|
|
|
|
|
|
|
|
|
|
class CBHG(nn.Module):
    """CBHG module: a recurrent neural network composed of:
        - 1-d convolution banks
        - Highway networks + residual connections
        - Bidirectional gated recurrent units

    Args:
        in_dim: channel count of the input; also the hidden size of each GRU
            direction, so the output's last dim is ``2 * in_dim``.
        K: number of convolution banks; bank ``k`` (1..K) uses kernel size k.
        projections: output channels of the conv projection layers; any
            sequence of ints. Defaults to ``[128, 128]``.
    """

    def __init__(self, in_dim, K=16, projections=None):
        super(CBHG, self).__init__()
        # None sentinel instead of a shared mutable default; list() also lets
        # callers pass a tuple.
        projections = [128, 128] if projections is None else list(projections)
        self.in_dim = in_dim
        self.relu = nn.ReLU()
        self.conv1d_banks = nn.ModuleList(
            [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
                             padding=k // 2, activation=self.relu)
             for k in range(1, K + 1)])
        self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        in_sizes = [K * in_dim] + projections[:-1]
        # ReLU on every projection except the last one.
        activations = [self.relu] * (len(projections) - 1) + [None]
        self.conv1d_projections = nn.ModuleList(
            [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
                             padding=1, activation=ac)
             for (in_size, out_size, ac) in zip(
                 in_sizes, projections, activations)])

        # Only used when the last projection size differs from in_dim.
        self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
        self.highways = nn.ModuleList(
            [Highway(in_dim, in_dim) for _ in range(4)])

        self.gru = nn.GRU(
            in_dim, in_dim, 1, batch_first=True, bidirectional=True)

    def forward(self, inputs):
        """Run the CBHG stack.

        Args:
            inputs: ``(B, T_in, in_dim)``. A channel-first ``(B, in_dim, T_in)``
                tensor is also accepted; the layout is detected by checking
                whether the last dim equals ``in_dim`` (ambiguous when
                ``T_in == in_dim``).

        Returns:
            ``(B, T_in, 2 * in_dim)`` bidirectional GRU outputs.
        """
        x = inputs

        # Needed to perform conv1d on time-axis
        # (B, in_dim, T_in)
        if x.size(-1) == self.in_dim:
            x = x.transpose(1, 2)

        T = x.size(-1)
        # Residual taken in (B, T_in, in_dim) layout whatever layout the caller
        # used. Adding `inputs` directly would mismatch the highway-branch
        # layout for channel-first callers.
        residual = x.transpose(1, 2)

        # (B, in_dim*K, T_in)
        # Concat conv1d bank outputs; even kernels yield T_in + 1 steps, so
        # every bank (and the max-pool) is trimmed back to T_in.
        x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
        assert x.size(1) == self.in_dim * len(self.conv1d_banks)
        x = self.max_pool1d(x)[:, :, :T]

        for conv1d in self.conv1d_projections:
            x = conv1d(x)

        # (B, T_in, in_dim)
        # Back to the original shape
        x = x.transpose(1, 2)

        if x.size(-1) != self.in_dim:
            x = self.pre_highway(x)

        # Residual connection (out-of-place, so autograd never sees an
        # in-place write into a transposed view of the projection output).
        x = x + residual
        for highway in self.highways:
            x = highway(x)

        # (B, T_in, in_dim*2)
        self.gru.flatten_parameters()
        outputs, _ = self.gru(x)
        return outputs
|
|
|
|
|
|
|
|
|
|
|
|
class Encoder(nn.Module):
    """Tacotron encoder: a ``Prenet`` (256 -> 128) feeding a 128-channel ``CBHG``.

    Args:
        in_dim: size of the last dimension of the encoder input (the
            ``Prenet`` input size).
    """

    def __init__(self, in_dim):
        super(Encoder, self).__init__()
        self.prenet = Prenet(in_dim, sizes=[256, 128])
        self.cbhg = CBHG(128, K=16, projections=[128, 128])

    def forward(self, inputs):
        """Encode ``inputs``.

        Args:
            inputs: tensor whose last dim is ``in_dim``; presumably
                ``(B, T_in, in_dim)``, matching CBHG's expected layout.

        Returns:
            CBHG output, ``(B, T_in, 256)`` (bidirectional GRU of size 128).
        """
        prenet_out = self.prenet(inputs)
        encoded = self.cbhg(prenet_out)
        return encoded
|
|
|
|
|
|
|
|
|
|
|
class Decoder(nn.Module):
    """Attention decoder that predicts ``r`` frames of ``memory_dim`` per step.

    Args:
        memory_dim: feature size of one decoder memory frame (mel-spectrogram).
        r: number of frames grouped into one decoder step.
        max_decoder_steps: upper bound on the number of decoder steps during
            greedy decoding (``memory is None``). Defaults to 200.
    """

    def __init__(self, memory_dim, r, max_decoder_steps=200):
        super(Decoder, self).__init__()
        self.max_decoder_steps = max_decoder_steps
        self.memory_dim = memory_dim
        self.r = r
        # input -> |Linear| -> processed_inputs
        self.input_layer = nn.Linear(256, 256, bias=False)
        # memory -> |Prenet| -> processed_memory
        self.prenet = Prenet(memory_dim * r, sizes=[256, 128])
        # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
        self.attention_rnn = AttentionWrapper(
            nn.GRUCell(256 + 128, 256),
            BahdanauAttention(256)
        )
        # (attention_rnn_hidden | attention context) -> |Linear| -> decoder_RNN_input
        self.project_to_decoder_in = nn.Linear(512, 256)
        # decoder_RNN_input -> |RNN| -> RNN_state
        self.decoder_rnns = nn.ModuleList(
            [nn.GRUCell(256, 256) for _ in range(2)])
        # RNN_state -> |Linear| -> mel_spec
        self.proj_to_mel = nn.Linear(256, memory_dim * r)

    def forward(self, inputs, memory=None, memory_lengths=None):
        """
        Decoder forward step.

        If decoder inputs are not given (e.g., at testing time), as noted in
        the Tacotron paper, greedy decoding is adopted.

        Args:
            inputs: Encoder outputs. (B, T_encoder, dim)
            memory: Decoder memory, i.e. mel-spectrogram. If None (at
                eval-time), decoder outputs are used as decoder inputs.
            memory_lengths: Encoder output (memory) lengths. If not None, used
                for attention masking.

        Returns:
            outputs: ``(B, T_decoder, memory_dim * r)`` predicted frames.
            alignments: per-step attention alignments, stacked batch-first.
        """
        B = inputs.size(0)

        # TODO: move this segment into the Attention module.
        processed_inputs = self.input_layer(inputs)
        if memory_lengths is not None:
            mask = get_mask_from_lengths(processed_inputs, memory_lengths)
        else:
            mask = None

        # Run greedy decoding if memory is None
        greedy = memory is None
        T_decoder = None

        if memory is not None:
            # Grouping multiple frames if necessary
            if memory.size(-1) == self.memory_dim:
                memory = memory.view(B, memory.size(1) // self.r, -1)
            assert memory.size(-1) == self.memory_dim * self.r,\
                " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
                                                             self.memory_dim, self.r)
            T_decoder = memory.size(1)

        # go frame - all-zero frame starting the sequence
        initial_memory = Variable(
            inputs.data.new(B, self.memory_dim * self.r).zero_())

        # Init decoder states
        attention_rnn_hidden = Variable(
            inputs.data.new(B, 256).zero_())
        decoder_rnn_hiddens = [Variable(
            inputs.data.new(B, 256).zero_())
            for _ in range(len(self.decoder_rnns))]
        current_context_vec = Variable(
            inputs.data.new(B, 256).zero_())

        # Time first (T_decoder, B, memory_dim)
        if memory is not None:
            memory = memory.transpose(0, 1)

        outputs = []
        alignments = []

        t = 0
        memory_input = initial_memory
        while True:
            if t > 0:
                # Greedy: feed back the last prediction; teacher forcing:
                # feed the previous ground-truth step.
                memory_input = outputs[-1] if greedy else memory[t - 1]
            # Prenet
            memory_input = self.prenet(memory_input)

            # Attention RNN
            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
                memory_input, current_context_vec, attention_rnn_hidden,
                inputs, processed_inputs=processed_inputs, mask=mask)

            # Concat RNN output and attention context vector
            decoder_input = self.project_to_decoder_in(
                torch.cat((attention_rnn_hidden, current_context_vec), -1))

            # Pass through the decoder RNNs
            for idx in range(len(self.decoder_rnns)):
                decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
                    decoder_input, decoder_rnn_hiddens[idx])
                # Residual connection
                decoder_input = decoder_rnn_hiddens[idx] + decoder_input

            output = decoder_input

            # predict mel vectors from decoder vectors
            output = self.proj_to_mel(output)

            outputs += [output]
            alignments += [alignment]

            t += 1

            if greedy:
                if t > 1 and is_end_of_frames(output):
                    break
                # `t` equals len(outputs) here; `>=` caps decoding at exactly
                # max_decoder_steps steps (`>` allowed one extra step).
                elif t >= self.max_decoder_steps:
                    print(" !! Decoder stopped with 'max_decoder_steps'. "
                          "Something is probably wrong.")
                    break
            else:
                if t >= T_decoder:
                    break

        assert greedy or len(outputs) == T_decoder

        # Back to batch first
        alignments = torch.stack(alignments).transpose(0, 1)
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()

        return outputs, alignments
|
|
|
|
|
|
|
|
|
|
|
|
def is_end_of_frames(output, eps=0.2):
    """Return whether every element of ``output`` is less than or equal to ``eps``.

    Used by ``Decoder.forward`` as the stop criterion for greedy decoding.
    The result is whatever ``Tensor.all()`` returns for the comparison mask.
    """
    below_threshold = output.data.le(eps)
    return below_threshold.all()
|