TTS/layers/tacotron.py

324 lines
12 KiB
Python
Raw Normal View History

2018-01-22 14:59:41 +00:00
# coding: utf-8
import torch
from torch import nn
2018-03-07 14:58:51 +00:00
from .attention import AttentionRNN
2018-01-22 14:59:41 +00:00
from .attention import get_mask_from_lengths
2018-04-03 10:24:57 +00:00
2018-01-22 14:59:41 +00:00
class Prenet(nn.Module):
    r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
    It creates as many layers as given by 'out_features'.

    Every layer is Linear -> ReLU -> Dropout(0.5).

    Args:
        in_features (int): size of the input vector
        out_features (int or list): size of each output sample.
            If it is a list (or any sequence), a new layer is created for
            each value. If it is an int, a single layer is created.
            Defaults to [256, 128].
    """

    def __init__(self, in_features, out_features=None):
        super(Prenet, self).__init__()
        # Normalise to a fresh list: avoids a shared mutable default and lets
        # the documented int form and tuples work with the slicing below.
        if out_features is None:
            out_features = [256, 128]
        elif isinstance(out_features, int):
            out_features = [out_features]
        else:
            out_features = list(out_features)
        # Layer i maps in_sizes[i] -> out_features[i]; each layer feeds the next.
        in_sizes = [in_features] + out_features[:-1]
        self.layers = nn.ModuleList(
            [nn.Linear(in_size, out_size)
             for (in_size, out_size) in zip(in_sizes, out_features)])
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, inputs):
        # Apply Linear -> ReLU -> Dropout for every layer in order.
        for linear in self.layers:
            inputs = self.dropout(self.relu(linear(inputs)))
        return inputs
class BatchNormConv1d(nn.Module):
    r"""A wrapper for Conv1d with BatchNorm. It sets the activation
    function between Conv and BatchNorm layers. BatchNorm layer
    is initialized with the TF default values for momentum and eps.

    Args:
        in_channels: size of each input sample
        out_channels: size of each output samples
        kernel_size: kernel size of conv filters
        stride: stride of conv filters
        padding: padding of conv filters
        activation: activation function set b/w Conv1d and BatchNorm

    Shapes:
        - input: batch x in_channels x time
        - output: batch x out_channels x time'
          (time' is determined by kernel_size, stride and padding)
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 activation=None):
        super(BatchNormConv1d, self).__init__()
        self.conv1d = nn.Conv1d(in_channels, out_channels,
                                kernel_size=kernel_size,
                                stride=stride, padding=padding, bias=False)
        # Following tensorflow's default parameters (decay 0.99, epsilon 1e-3).
        # PyTorch's `momentum` is the weight given to the *new* batch
        # statistic (running = (1 - m) * running + m * batch), i.e. it is
        # 1 - TF's decay, so TF's 0.99 corresponds to momentum=0.01 here.
        self.bn = nn.BatchNorm1d(out_channels, momentum=0.01, eps=1e-3)
        self.activation = activation

    def forward(self, x):
        x = self.conv1d(x)
        # Optional non-linearity applied before normalisation.
        if self.activation is not None:
            x = self.activation(x)
        return self.bn(x)
class Highway(nn.Module):
    """Highway layer: gated blend of a ReLU transform and the identity.

    out = relu(H(x)) * g + x * (1 - g), where g = sigmoid(T(x)).
    The H bias starts at zero and the gate (T) bias starts at -1.
    Because x is added back directly, out_size is expected to equal in_size.
    """

    def __init__(self, in_size, out_size):
        super(Highway, self).__init__()
        transform = nn.Linear(in_size, out_size)
        transform.bias.data.zero_()
        gate = nn.Linear(in_size, out_size)
        # Negative gate bias: sigmoid(-1) < 0.5 before the weight contribution.
        gate.bias.data.fill_(-1)
        self.H = transform
        self.T = gate
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        gate = self.sigmoid(self.T(inputs))
        carried = inputs * (1.0 - gate)
        transformed = self.relu(self.H(inputs))
        return transformed * gate + carried
class CBHG(nn.Module):
    """CBHG module: a recurrent neural network composed of:
        - 1-d convolution banks
        - Highway networks + residual connections
        - Bidirectional gated recurrent units

        Args:
            in_features (int): sample size
            K (int): max filter size in conv bank
            projections (list): conv channel sizes for conv projections
                (any sequence; defaults to [128, 128])
            num_highways (int): number of highways layers

        Shapes:
            - input: batch x time x dim
            - output: batch x time x dim*2
    """

    def __init__(self, in_features, K=16, projections=None, num_highways=4):
        super(CBHG, self).__init__()
        # Fresh list: avoids a shared mutable default and accepts tuples.
        projections = [128, 128] if projections is None else list(projections)
        self.in_features = in_features
        self.relu = nn.ReLU()
        # list of conv1d bank with filter size k=1...K
        # TODO: try dilational layers instead
        self.conv1d_banks = nn.ModuleList(
            [BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
                             padding=k // 2, activation=self.relu)
             for k in range(1, K + 1)])
        # max pooling of conv bank
        # TODO: try average pooling OR larger kernel size
        self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
        out_features = [K * in_features] + projections[:-1]
        # ReLU after every projection except the last one.
        activations = [self.relu] * (len(projections) - 1)
        activations += [None]

        # setup conv1d projection layers
        layer_set = []
        for (in_size, out_size, ac) in zip(out_features, projections, activations):
            layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
                                    padding=1, activation=ac)
            layer_set.append(layer)
        self.conv1d_projections = nn.ModuleList(layer_set)
        # setup Highway layers; pre_highway maps projections[-1] back to
        # in_features (it is only applied when those sizes differ)
        self.pre_highway = nn.Linear(projections[-1], in_features, bias=False)
        self.highways = nn.ModuleList(
            [Highway(in_features, in_features) for _ in range(num_highways)])

        # bi-directional GRU layer
        self.gru = nn.GRU(
            in_features, in_features, 1, batch_first=True, bidirectional=True)

    def forward(self, inputs):
        # (B, T_in, in_features) -> (B, in_features, T_in) so conv1d runs over
        # time. The input must be batch x time x in_features (the residual add
        # below requires it), so always transpose; a size-based check would be
        # ambiguous whenever T_in == in_features.
        x = inputs.transpose(1, 2)
        T = x.size(-1)
        # Concat conv1d bank outputs -> (B, in_features*K, T_in). Even kernel
        # sizes with padding k // 2 yield T_in + 1 frames, hence the crop.
        outs = [conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks]
        x = torch.cat(outs, dim=1)
        assert x.size(1) == self.in_features * len(self.conv1d_banks)
        # kernel 2 / stride 1 / padding 1 pooling yields T_in + 1 frames.
        x = self.max_pool1d(x)[:, :, :T]
        for conv1d in self.conv1d_projections:
            x = conv1d(x)
        # Back to the original layout: (B, T_in, projections[-1])
        x = x.transpose(1, 2)
        if x.size(-1) != self.in_features:
            x = self.pre_highway(x)
        # Residual connection. Out-of-place so the tensor produced by the
        # projection stack (possibly a transposed view) is never mutated.
        # TODO: try residual scaling as in Deep Voice 3
        # TODO: try plain residual layers
        x = x + inputs
        for highway in self.highways:
            x = highway(x)
        # (B, T_in, in_features*2)
        # TODO: replace GRU with convolution as in Deep Voice 3
        self.gru.flatten_parameters()
        outputs, _ = self.gru(x)
        return outputs
class Encoder(nn.Module):
    r"""Encoder: a two-layer Prenet (256 -> 128) feeding a CBHG over 128 features."""

    def __init__(self, in_features):
        super(Encoder, self).__init__()
        prenet_sizes = [256, 128]
        bottleneck = prenet_sizes[-1]
        self.prenet = Prenet(in_features, out_features=prenet_sizes)
        self.cbhg = CBHG(bottleneck, K=16, projections=[bottleneck] * 2)

    def forward(self, inputs):
        r"""Encode embedded input symbols.

        Args:
            inputs (FloatTensor): embedding features

        Shapes:
            - inputs: batch x time x in_features
            - outputs: batch x time x 128*2
        """
        prenet_out = self.prenet(inputs)
        encoded = self.cbhg(prenet_out)
        return encoded
2018-01-22 14:59:41 +00:00
class Decoder(nn.Module):
    r"""Decoder module.

    Args:
        in_features (int): input vector (encoder output) sample size.
        memory_dim (int): memory vector (prev. time-step output) sample size.
        r (int): number of outputs per time step.
        eps (float): threshold for detecting the end of a sentence.
        mode (str): stored as ``self.mode``; not read inside this class.
    """

    def __init__(self, in_features, memory_dim, r, eps=0, mode='train'):
        super(Decoder, self).__init__()
        self.mode = mode
        self.max_decoder_steps = 200
        self.memory_dim = memory_dim
        self.eps = eps
        self.r = r
        # memory -> |Prenet| -> processed_memory
        self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
        # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
        self.attention_rnn = AttentionRNN(256, in_features, 128)
        # (attention RNN state | attention context) -> |Linear| -> decoder_RNN_input
        self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
        # decoder_RNN_input -> |RNN| -> RNN_state
        self.decoder_rnns = nn.ModuleList(
            [nn.GRUCell(256, 256) for _ in range(2)])
        # RNN_state -> |Linear| -> mel_spec
        self.proj_to_mel = nn.Linear(256, memory_dim * r)

    def forward(self, inputs, memory=None):
        """Decoder forward pass.

        If decoder inputs are not given (e.g., at testing time), as noted in
        the Tacotron paper, greedy decoding is adopted.

        Args:
            inputs: Encoder outputs.
            memory (None): Decoder memory (autoregression). When given in
                training mode, the ground-truth frames are fed back (teacher
                forcing). When None, or in eval mode, the last output is fed
                back as the next input. When given, decoding runs for exactly
                as many steps as memory has (grouped) frames; otherwise it
                stops on ``is_end_of_frames`` or ``max_decoder_steps``.

        Shapes:
            - inputs: batch x time x encoder_out_dim
            - memory: batch x #mel_specs x mel_spec_dim

        Returns:
            outputs: batch x T_decoder x (memory_dim * r)
            alignments: per-step attention alignments, stacked batch first.

        Raises:
            ValueError: if memory cannot be grouped into memory_dim * r
                features per decoder step.
        """
        B = inputs.size(0)
        # Greedy (feed back own predictions) unless training with memory.
        greedy = not self.training or memory is None
        T_decoder = None
        if memory is not None:
            # Grouping multiple frames if necessary:
            # (B, T, memory_dim) -> (B, T // r, memory_dim * r)
            if memory.size(-1) == self.memory_dim:
                memory = memory.view(B, memory.size(1) // self.r, -1)
            if memory.size(-1) != self.memory_dim * self.r:
                raise ValueError(
                    " !! Dimension mismatch {} vs {} * {}".format(
                        memory.size(-1), self.memory_dim, self.r))
            T_decoder = memory.size(1)
        # go frame - all-zero frame starting the sequence
        initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
        # Init decoder states
        attention_rnn_hidden = inputs.data.new(B, 256).zero_()
        decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
                               for _ in range(len(self.decoder_rnns))]
        # project_to_decoder_in expects cat(attention state [256],
        # context [in_features]), so the context has the encoder feature size
        # (a hard-coded 256 only worked when in_features == 256).
        current_context_vec = inputs.data.new(B, inputs.size(-1)).zero_()
        # Time first (T_decoder, B, memory_dim * r)
        if memory is not None:
            memory = memory.transpose(0, 1)
        outputs = []
        alignments = []
        t = 0
        memory_input = initial_memory
        while True:
            if t > 0:
                memory_input = outputs[-1] if greedy else memory[t - 1]
            # Prenet
            processed_memory = self.prenet(memory_input)
            # Attention RNN
            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
                processed_memory, current_context_vec, attention_rnn_hidden, inputs)
            # Concat RNN output and attention context vector
            decoder_input = self.project_to_decoder_in(
                torch.cat((attention_rnn_hidden, current_context_vec), -1))
            # Pass through the decoder RNNs
            for idx, decoder_rnn in enumerate(self.decoder_rnns):
                decoder_rnn_hiddens[idx] = decoder_rnn(
                    decoder_input, decoder_rnn_hiddens[idx])
                # Residual connection
                decoder_input = decoder_rnn_hiddens[idx] + decoder_input
            # predict mel vectors from decoder vectors
            output = self.proj_to_mel(decoder_input)
            outputs += [output]
            alignments += [alignment]
            t += 1
            if memory is not None:
                # Target length is known: decode exactly T_decoder steps.
                if t >= T_decoder:
                    break
            # NOTE(review): view(self.r, -1) mixes batch items when B > 1;
            # the stop heuristic is presumably meant for B == 1 — confirm.
            elif t > 1 and is_end_of_frames(output.view(self.r, -1), alignment, self.eps):
                break
            elif t >= self.max_decoder_steps:
                print(" !! Decoder stopped with 'max_decoder_steps'. "
                      "Something is probably wrong.")
                break
        assert greedy or len(outputs) == T_decoder
        # Back to batch first
        alignments = torch.stack(alignments).transpose(0, 1)
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
        return outputs, alignments
2018-01-22 14:59:41 +00:00
def is_end_of_frames(output, alignment, eps=0.05):  # 0.2
    """Heuristic stop check used during greedy decoding.

    Returns a truthy tensor when (a) at least one column of ``output`` has
    every row <= ``eps``, and (b) the alignment mass in the second half of
    the ``alignment`` columns, summed over all rows, exceeds 0.7.
    Condition (b) is only evaluated when (a) holds; otherwise the falsy
    result of (a) is returned.
    """
    quiet_columns = (output.data <= eps).prod(0) > 0
    found_quiet = quiet_columns.any()
    if not found_quiet:
        return found_quiet
    midpoint = alignment.shape[1] // 2
    return alignment.data[:, midpoint:].sum() > 0.7