mirror of https://github.com/coqui-ai/TTS.git
534 lines
20 KiB
Python
534 lines
20 KiB
Python
from math import sqrt
|
|
import torch
|
|
from torch.autograd import Variable
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
|
|
class Linear(nn.Module):
|
|
def __init__(self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
init_gain='linear'):
|
|
super(Linear, self).__init__()
|
|
self.linear_layer = torch.nn.Linear(
|
|
in_features, out_features, bias=bias)
|
|
self._init_w(init_gain)
|
|
|
|
def _init_w(self, init_gain):
|
|
torch.nn.init.xavier_uniform_(
|
|
self.linear_layer.weight,
|
|
gain=torch.nn.init.calculate_gain(init_gain))
|
|
|
|
def forward(self, x):
|
|
return self.linear_layer(x)
|
|
|
|
|
|
class LinearBN(nn.Module):
|
|
def __init__(self,
|
|
in_features,
|
|
out_features,
|
|
bias=True,
|
|
init_gain='linear'):
|
|
super(LinearBN, self).__init__()
|
|
self.linear_layer = torch.nn.Linear(
|
|
in_features, out_features, bias=bias)
|
|
self.bn = nn.BatchNorm1d(out_features)
|
|
self._init_w(init_gain)
|
|
|
|
def _init_w(self, init_gain):
|
|
torch.nn.init.xavier_uniform_(
|
|
self.linear_layer.weight,
|
|
gain=torch.nn.init.calculate_gain(init_gain))
|
|
|
|
def forward(self, x):
|
|
out = self.linear_layer(x)
|
|
if len(out.shape)==3:
|
|
out = out.permute(1, 2, 0)
|
|
out = self.bn(out)
|
|
if len(out.shape) == 3:
|
|
out = out.permute(2, 0, 1)
|
|
return out
|
|
|
|
|
|
class Prenet(nn.Module):
|
|
def __init__(self, in_features, prenet_type, out_features=[256, 256]):
|
|
super(Prenet, self).__init__()
|
|
self.prenet_type = prenet_type
|
|
in_features = [in_features] + out_features[:-1]
|
|
if prenet_type == "bn":
|
|
self.layers = nn.ModuleList([
|
|
LinearBN(in_size, out_size, bias=False)
|
|
for (in_size, out_size) in zip(in_features, out_features)
|
|
])
|
|
elif prenet_type == "original":
|
|
self.layers = nn.ModuleList(
|
|
[Linear(in_size, out_size, bias=False)
|
|
for (in_size, out_size) in zip(in_features, out_features)
|
|
])
|
|
|
|
def forward(self, x):
|
|
for linear in self.layers:
|
|
if self.prenet_type == "original":
|
|
x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
|
|
elif self.prenet_type == "bn":
|
|
x = F.relu(linear(x))
|
|
return x
|
|
|
|
|
|
class ConvBNBlock(nn.Module):
|
|
def __init__(self, in_channels, out_channels, kernel_size, nonlinear=None):
|
|
super(ConvBNBlock, self).__init__()
|
|
assert (kernel_size - 1) % 2 == 0
|
|
padding = (kernel_size - 1) // 2
|
|
conv1d = nn.Conv1d(
|
|
in_channels, out_channels, kernel_size, padding=padding)
|
|
norm = nn.BatchNorm1d(out_channels)
|
|
dropout = nn.Dropout(p=0.5)
|
|
if nonlinear == 'relu':
|
|
self.net = nn.Sequential(conv1d, norm, nn.ReLU(), dropout)
|
|
elif nonlinear == 'tanh':
|
|
self.net = nn.Sequential(conv1d, norm, nn.Tanh(), dropout)
|
|
else:
|
|
self.net = nn.Sequential(conv1d, norm, dropout)
|
|
|
|
def forward(self, x):
|
|
output = self.net(x)
|
|
return output
|
|
|
|
|
|
class LocationLayer(nn.Module):
|
|
def __init__(self, attention_n_filters, attention_kernel_size,
|
|
attention_dim):
|
|
super(LocationLayer, self).__init__()
|
|
self.location_conv = nn.Conv1d(
|
|
in_channels=2,
|
|
out_channels=attention_n_filters,
|
|
kernel_size=31,
|
|
stride=1,
|
|
padding=(31 - 1) // 2,
|
|
bias=False)
|
|
self.location_dense = Linear(
|
|
attention_n_filters, attention_dim, bias=False, init_gain='tanh')
|
|
|
|
def forward(self, attention_cat):
|
|
processed_attention = self.location_conv(attention_cat)
|
|
processed_attention = self.location_dense(
|
|
processed_attention.transpose(1, 2))
|
|
return processed_attention
|
|
|
|
|
|
class Attention(nn.Module):
|
|
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
|
|
attention_location_n_filters, attention_location_kernel_size,
|
|
windowing, norm, forward_attn, trans_agent):
|
|
super(Attention, self).__init__()
|
|
self.query_layer = Linear(
|
|
attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
|
|
self.inputs_layer = Linear(
|
|
embedding_dim, attention_dim, bias=False, init_gain='tanh')
|
|
self.v = Linear(attention_dim, 1, bias=True)
|
|
if trans_agent:
|
|
self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
|
|
self.location_layer = LocationLayer(attention_location_n_filters,
|
|
attention_location_kernel_size,
|
|
attention_dim)
|
|
self._mask_value = -float("inf")
|
|
self.windowing = windowing
|
|
self.win_idx = None
|
|
self.norm = norm
|
|
self.forward_attn = forward_attn
|
|
self.trans_agent = trans_agent
|
|
|
|
def init_win_idx(self):
|
|
self.win_idx = -1
|
|
self.win_back = 2
|
|
self.win_front = 6
|
|
|
|
def init_forward_attn_state(self, inputs):
|
|
"""
|
|
Init forward attention states
|
|
"""
|
|
B = inputs.shape[0]
|
|
T = inputs.shape[1]
|
|
self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7 ], dim=1).to(inputs.device)
|
|
self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)
|
|
|
|
def get_attention(self, query, processed_inputs, attention_cat):
|
|
processed_query = self.query_layer(query.unsqueeze(1))
|
|
processed_attention_weights = self.location_layer(attention_cat)
|
|
energies = self.v(
|
|
torch.tanh(processed_query + processed_attention_weights +
|
|
processed_inputs))
|
|
|
|
energies = energies.squeeze(-1)
|
|
return energies, processed_query
|
|
|
|
def apply_windowing(self, attention, inputs):
|
|
back_win = self.win_idx - self.win_back
|
|
front_win = self.win_idx + self.win_front
|
|
if back_win > 0:
|
|
attention[:, :back_win] = -float("inf")
|
|
if front_win < inputs.shape[1]:
|
|
attention[:, front_win:] = -float("inf")
|
|
# this is a trick to solve a special problem.
|
|
# but it does not hurt.
|
|
if self.win_idx == -1:
|
|
attention[:, 0] = attention.max()
|
|
# Update the window
|
|
self.win_idx = torch.argmax(attention, 1).long()[0].item()
|
|
return attention
|
|
|
|
def apply_forward_attention(self, inputs, alignment, processed_query):
|
|
# forward attention
|
|
prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
|
|
alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-8) * alignment
|
|
self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
|
|
# compute context
|
|
context = torch.bmm(self.alpha.unsqueeze(1), inputs)
|
|
context = context.squeeze(1)
|
|
# compute transition agent
|
|
if self.trans_agent:
|
|
ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
|
|
self.u = torch.sigmoid(self.ta(ta_input))
|
|
return context, self.alpha, alignment
|
|
|
|
def forward(self, attention_hidden_state, inputs, processed_inputs,
|
|
attention_cat, mask):
|
|
attention, processed_query = self.get_attention(
|
|
attention_hidden_state, processed_inputs, attention_cat)
|
|
|
|
# apply masking
|
|
if mask is not None:
|
|
attention.data.masked_fill_(1 - mask, self._mask_value)
|
|
# apply windowing - only in eval mode
|
|
if not self.training and self.windowing:
|
|
attention = self.apply_windowing(attention, inputs)
|
|
# normalize attention values
|
|
if self.norm == "softmax":
|
|
alignment = torch.softmax(attention, dim=-1)
|
|
elif self.norm == "sigmoid":
|
|
alignment = torch.sigmoid(attention) / torch.sigmoid(
|
|
attention).sum(dim=1).unsqueeze(1)
|
|
else:
|
|
raise RuntimeError("Unknown value for attention norm type")
|
|
# apply forward attention if enabled
|
|
if self.forward_attn:
|
|
return self.apply_forward_attention(inputs, alignment, processed_query)
|
|
else:
|
|
context = torch.bmm(alignment.unsqueeze(1), inputs)
|
|
context = context.squeeze(1)
|
|
return context, alignment, alignment
|
|
|
|
|
|
class Postnet(nn.Module):
|
|
def __init__(self, mel_dim, num_convs=5):
|
|
super(Postnet, self).__init__()
|
|
self.convolutions = nn.ModuleList()
|
|
self.convolutions.append(
|
|
ConvBNBlock(mel_dim, 512, kernel_size=5, nonlinear='tanh'))
|
|
for i in range(1, num_convs - 1):
|
|
self.convolutions.append(
|
|
ConvBNBlock(512, 512, kernel_size=5, nonlinear='tanh'))
|
|
self.convolutions.append(
|
|
ConvBNBlock(512, mel_dim, kernel_size=5, nonlinear=None))
|
|
|
|
def forward(self, x):
|
|
for layer in self.convolutions:
|
|
x = layer(x)
|
|
return x
|
|
|
|
|
|
class Encoder(nn.Module):
|
|
def __init__(self, in_features=512):
|
|
super(Encoder, self).__init__()
|
|
convolutions = []
|
|
for _ in range(3):
|
|
convolutions.append(
|
|
ConvBNBlock(in_features, in_features, 5, 'relu'))
|
|
self.convolutions = nn.Sequential(*convolutions)
|
|
self.lstm = nn.LSTM(
|
|
in_features,
|
|
int(in_features / 2),
|
|
num_layers=1,
|
|
batch_first=True,
|
|
bidirectional=True)
|
|
self.rnn_state = None
|
|
|
|
def forward(self, x, input_lengths):
|
|
x = self.convolutions(x)
|
|
x = x.transpose(1, 2)
|
|
input_lengths = input_lengths.cpu().numpy()
|
|
x = nn.utils.rnn.pack_padded_sequence(
|
|
x, input_lengths, batch_first=True)
|
|
self.lstm.flatten_parameters()
|
|
outputs, _ = self.lstm(x)
|
|
outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
|
outputs,
|
|
batch_first=True,
|
|
)
|
|
return outputs
|
|
|
|
def inference(self, x):
|
|
x = self.convolutions(x)
|
|
x = x.transpose(1, 2)
|
|
self.lstm.flatten_parameters()
|
|
outputs, _ = self.lstm(x)
|
|
return outputs
|
|
|
|
def inference_truncated(self, x):
|
|
"""
|
|
Preserve encoder state for continuous inference
|
|
"""
|
|
x = self.convolutions(x)
|
|
x = x.transpose(1, 2)
|
|
self.lstm.flatten_parameters()
|
|
outputs, self.rnn_state = self.lstm(x, self.rnn_state)
|
|
return outputs
|
|
|
|
# adapted from https://github.com/NVIDIA/tacotron2/
|
|
class Decoder(nn.Module):
|
|
def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent):
|
|
super(Decoder, self).__init__()
|
|
self.mel_channels = inputs_dim
|
|
self.r = r
|
|
self.encoder_embedding_dim = in_features
|
|
self.attention_rnn_dim = 1024
|
|
self.decoder_rnn_dim = 1024
|
|
self.prenet_dim = 256
|
|
self.max_decoder_steps = 1000
|
|
self.gate_threshold = 0.5
|
|
self.p_attention_dropout = 0.1
|
|
self.p_decoder_dropout = 0.1
|
|
|
|
self.prenet = Prenet(self.mel_channels * r, prenet_type,
|
|
[self.prenet_dim, self.prenet_dim])
|
|
|
|
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
|
|
self.attention_rnn_dim)
|
|
|
|
self.attention_layer = Attention(self.attention_rnn_dim, in_features,
|
|
128, 32, 31, attn_win, attn_norm, forward_attn, trans_agent)
|
|
|
|
self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
|
|
self.decoder_rnn_dim, 1)
|
|
|
|
self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
|
|
self.mel_channels * r)
|
|
|
|
self.stopnet = nn.Sequential(
|
|
nn.Dropout(0.1),
|
|
Linear(self.decoder_rnn_dim + self.mel_channels * r,
|
|
1,
|
|
bias=True,
|
|
init_gain='sigmoid'))
|
|
|
|
self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
|
|
self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
|
|
self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
|
|
self.memory_truncated = None
|
|
|
|
def get_go_frame(self, inputs):
|
|
B = inputs.size(0)
|
|
memory = self.go_frame_init(inputs.data.new_zeros(B).long())
|
|
return memory
|
|
|
|
def _init_states(self, inputs, mask, keep_states=False):
|
|
B = inputs.size(0)
|
|
T = inputs.size(1)
|
|
|
|
if not keep_states:
|
|
self.attention_hidden = self.attention_rnn_init(
|
|
inputs.data.new_zeros(B).long())
|
|
self.attention_cell = Variable(
|
|
inputs.data.new(B, self.attention_rnn_dim).zero_())
|
|
|
|
self.decoder_hidden = self.decoder_rnn_inits(
|
|
inputs.data.new_zeros(B).long())
|
|
self.decoder_cell = Variable(
|
|
inputs.data.new(B, self.decoder_rnn_dim).zero_())
|
|
|
|
self.context = Variable(
|
|
inputs.data.new(B, self.encoder_embedding_dim).zero_())
|
|
|
|
self.attention_weights = Variable(inputs.data.new(B, T).zero_())
|
|
self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
|
|
|
|
self.inputs = inputs
|
|
self.processed_inputs = self.attention_layer.inputs_layer(inputs)
|
|
self.mask = mask
|
|
|
|
def _reshape_memory(self, memories):
|
|
memories = memories.view(
|
|
memories.size(0), int(memories.size(1) / self.r), -1)
|
|
memories = memories.transpose(0, 1)
|
|
return memories
|
|
|
|
def _parse_outputs(self, outputs, stop_tokens, alignments):
|
|
alignments = torch.stack(alignments).transpose(0, 1)
|
|
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
|
stop_tokens = stop_tokens.contiguous()
|
|
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
|
outputs = outputs.view(
|
|
outputs.size(0), -1, self.mel_channels)
|
|
outputs = outputs.transpose(1, 2)
|
|
return outputs, stop_tokens, alignments
|
|
|
|
def decode(self, memory):
|
|
cell_input = torch.cat((memory, self.context), -1)
|
|
self.attention_hidden, self.attention_cell = self.attention_rnn(
|
|
cell_input, (self.attention_hidden, self.attention_cell))
|
|
self.attention_hidden = F.dropout(
|
|
self.attention_hidden, self.p_attention_dropout, self.training)
|
|
self.attention_cell = F.dropout(
|
|
self.attention_cell, self.p_attention_dropout, self.training)
|
|
|
|
attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
|
|
self.attention_weights_cum.unsqueeze(1)),
|
|
dim=1)
|
|
self.context, self.attention_weights, alignments = self.attention_layer(
|
|
self.attention_hidden, self.inputs, self.processed_inputs,
|
|
attention_cat, self.mask)
|
|
|
|
self.attention_weights_cum += alignments
|
|
memory = torch.cat(
|
|
(self.attention_hidden, self.context), -1)
|
|
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
|
memory, (self.decoder_hidden, self.decoder_cell))
|
|
self.decoder_hidden = F.dropout(self.decoder_hidden,
|
|
self.p_decoder_dropout, self.training)
|
|
self.decoder_cell = F.dropout(self.decoder_cell,
|
|
self.p_decoder_dropout, self.training)
|
|
|
|
decoder_hidden_context = torch.cat(
|
|
(self.decoder_hidden, self.context), dim=1)
|
|
|
|
decoder_output = self.linear_projection(
|
|
decoder_hidden_context)
|
|
|
|
stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
|
|
|
|
gate_prediction = self.stopnet(stopnet_input)
|
|
return decoder_output, gate_prediction, self.attention_weights
|
|
|
|
def forward(self, inputs, memories, mask):
|
|
memory = self.get_go_frame(inputs).unsqueeze(0)
|
|
memories = self._reshape_memory(memories)
|
|
memories = torch.cat((memory, memories), dim=0)
|
|
memories = self.prenet(memories)
|
|
|
|
self._init_states(inputs, mask=mask)
|
|
if self.attention_layer.forward_attn:
|
|
self.attention_layer.init_forward_attn_state(inputs)
|
|
|
|
outputs, stop_tokens, alignments = [], [], []
|
|
while len(outputs) < memories.size(0) - 1:
|
|
memory = memories[len(outputs)]
|
|
mel_output, stop_token, attention_weights = self.decode(
|
|
memory)
|
|
outputs += [mel_output.squeeze(1)]
|
|
stop_tokens += [stop_token.squeeze(1)]
|
|
alignments += [attention_weights]
|
|
|
|
outputs, stop_tokens, alignments = self._parse_outputs(
|
|
outputs, stop_tokens, alignments)
|
|
|
|
return outputs, stop_tokens, alignments
|
|
|
|
def inference(self, inputs):
|
|
memory = self.get_go_frame(inputs)
|
|
self._init_states(inputs, mask=None)
|
|
|
|
self.attention_layer.init_win_idx()
|
|
if self.attention_layer.forward_attn:
|
|
self.attention_layer.init_forward_attn_state(inputs)
|
|
|
|
outputs, stop_tokens, alignments, t = [], [], [], 0
|
|
stop_flags = [False, False, False]
|
|
stop_count = 0
|
|
while True:
|
|
memory = self.prenet(memory)
|
|
mel_output, stop_token, alignment = self.decode(memory)
|
|
stop_token = torch.sigmoid(stop_token.data)
|
|
outputs += [mel_output.squeeze(1)]
|
|
stop_tokens += [stop_token]
|
|
alignments += [alignment]
|
|
|
|
stop_flags[0] = stop_flags[0] or stop_token > 0.5
|
|
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8 and t > inputs.shape[1])
|
|
stop_flags[2] = t > inputs.shape[1] * 2
|
|
if all(stop_flags):
|
|
stop_count += 1
|
|
if stop_count > 2:
|
|
break
|
|
elif len(outputs) == self.max_decoder_steps:
|
|
print(" | > Decoder stopped with 'max_decoder_steps")
|
|
break
|
|
|
|
memory = mel_output
|
|
t += 1
|
|
|
|
outputs, stop_tokens, alignments = self._parse_outputs(
|
|
outputs, stop_tokens, alignments)
|
|
|
|
return outputs, stop_tokens, alignments
|
|
|
|
def inference_truncated(self, inputs):
|
|
"""
|
|
Preserve decoder states for continuous inference
|
|
"""
|
|
if self.memory_truncated is None:
|
|
self.memory_truncated = self.get_go_frame(inputs)
|
|
self._init_states(inputs, mask=None, keep_states=False)
|
|
else:
|
|
self._init_states(inputs, mask=None, keep_states=True)
|
|
|
|
self.attention_layer.init_win_idx()
|
|
if self.attention_layer.forward_attn:
|
|
self.attention_layer.init_forward_attn_state(inputs)
|
|
outputs, stop_tokens, alignments, t = [], [], [], 0
|
|
stop_flags = [False, False, False]
|
|
stop_count = 0
|
|
while True:
|
|
memory = self.prenet(self.memory_truncated)
|
|
mel_output, stop_token, alignment = self.decode(memory)
|
|
stop_token = torch.sigmoid(stop_token.data)
|
|
outputs += [mel_output.squeeze(1)]
|
|
stop_tokens += [stop_token]
|
|
alignments += [alignment]
|
|
|
|
stop_flags[0] = stop_flags[0] or stop_token > 0.5
|
|
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
|
|
stop_flags[2] = t > inputs.shape[1] * 2
|
|
if all(stop_flags):
|
|
stop_count += 1
|
|
if stop_count > 2:
|
|
break
|
|
elif len(outputs) == self.max_decoder_steps:
|
|
print(" | > Decoder stopped with 'max_decoder_steps")
|
|
break
|
|
|
|
self.memory_truncated = mel_output
|
|
t += 1
|
|
|
|
outputs, stop_tokens, alignments = self._parse_outputs(
|
|
outputs, stop_tokens, alignments)
|
|
|
|
return outputs, stop_tokens, alignments
|
|
|
|
|
|
def inference_step(self, inputs, t, memory=None):
|
|
"""
|
|
For debug purposes
|
|
"""
|
|
if t == 0:
|
|
memory = self.get_go_frame(inputs)
|
|
self._init_states(inputs, mask=None)
|
|
|
|
memory = self.prenet(memory)
|
|
mel_output, stop_token, alignment = self.decode(memory)
|
|
stop_token = torch.sigmoid(stop_token.data)
|
|
memory = mel_output
|
|
return mel_output, stop_token, alignment
|