mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'dev-tacotron2-optional-loc' into dev-tacotron2
commit ad39810c5a

.compute
@@ -1,12 +1,14 @@
#!/bin/bash
ls ${SHARED_DIR}/data/mozilla/Judy/
yes | apt-get install sox
yes | apt-get install ffmpeg
soxi /data/ro/shared/data/mozilla/Judy/batch6/wavs_no_processing/6_126.wav
pip3 install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
yes | apt-get install espeak
yes | apt-get install tmux
yes | apt-get install zsh
pip3 install https://download.pytorch.org/whl/cu100/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
# wget https://www.dropbox.com/s/m8waow6b3ydpf6h/MozillaDataset.tar.gz?dl=0 -O /data/rw/home/mozilla.tar
wget https://www.dropbox.com/s/wqn5v3wkktw9lmo/install.sh?dl=0 -O install.sh
sudo sh install.sh
python3 setup.py develop
# wget https://www.dropbox.com/s/evaouukiwb7krz8/MozillaDataset.tar.gz?dl=0 -O ${USER_DIR}/MozillaDataset.tar.gz
# tar -xzvf ${USER_DIR}/MozillaDataset.tar.gz --no-same-owner -C ${USER_DIR}
# python3 distribute.py --config_path config_cluster.json --data_path ${USER_DIR}/MozillaDataset/Mozilla/ --restore_path ${USER_DIR}/best_model_4583.pth.tar
python3 distribute.py --config_path config_cluster.json --data_path ${SHARED_DIR}/data/mozilla/Judy/
python3 distribute.py --config_path config_cluster.json --data_path ${USER_DIR}/MozillaAll2/Mozilla/ --restore_path ${USER_DIR}/checkpoint_123000_4761.pth.tar
# python3 distribute.py --config_path config_cluster.json --data_path ${SHARED_DIR}/data/mozilla/Judy/
# while true; do sleep 1000000; done
@@ -1,6 +1,6 @@
{
    "run_name": "mozilla-nomask-fattn-bn",
    "run_description": "Finetune 4702 orignal -> bn prenet - Mozilla with prenet bn, no mask, batch group size 0",
    "run_name": "mozilla-fattn",
    "run_description": "Finetune 4761 with BN + Dropout. It is to compare to 4780 and see how dropout behaves with BN.",

    "audio":{
        // Audio processing parameters

@@ -41,12 +41,14 @@
    "memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
    "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
    "prenet_type": "bn", // ONLY TACOTRON2 - "original" or "bn".
    "use_forward_attn": false, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
    "prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
    "use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
    "transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
    "location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
    "loss_masking": false, // enable / disable loss masking against the sequence padding.
    "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.

    "batch_size": 24, // Batch size for training. Lower values than 32 might cause hard to learn attention.
    "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
    "eval_batch_size":16,
    "r": 1, // Number of frames to predict for step.
    "wd": 0.000001, // Weight decay weight.

@@ -59,13 +61,13 @@
    "run_eval": true,
    "test_delay_epochs": 1, //Until attention is aligned, testing only wastes computation time.
    "data_path": "/media/erogol/data_ssd/Data/LJSpeech-1.1", // DATASET-RELATED: can overwritten from command argument
    "meta_file_train": "metadata.txt", // DATASET-RELATED: metafile for training dataloader.
    "meta_file_train": "metadata_train.txt", // DATASET-RELATED: metafile for training dataloader.
    "meta_file_val": "metadata_val.txt", // DATASET-RELATED: metafile for evaluation dataloader.
    "dataset": "mozilla", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
    "min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training
    "max_seq_len": 150, // DATASET-RELATED: maximum text length
    "output_path": "../keep/", // DATASET-RELATED: output path for all training outputs.
    "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
    "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
    "num_val_loader_workers": 4, // number of evaluation data loader processes.
    "phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
    "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
@@ -53,9 +53,10 @@ class LinearBN(nn.Module):


class Prenet(nn.Module):
    def __init__(self, in_features, prenet_type, out_features=[256, 256]):
    def __init__(self, in_features, prenet_type, prenet_dropout, out_features=[256, 256]):
        super(Prenet, self).__init__()
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        in_features = [in_features] + out_features[:-1]
        if prenet_type == "bn":
            self.layers = nn.ModuleList([

@@ -70,9 +71,9 @@ class Prenet(nn.Module):

    def forward(self, x):
        for linear in self.layers:
            if self.prenet_type == "original":
            if self.prenet_dropout:
                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
            elif self.prenet_type == "bn":
            else:
                x = F.relu(linear(x))
        return x
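For orientation, a minimal self-contained sketch of what the reworked Prenet amounts to after this change, assuming plain nn.Linear in place of the repo's Linear/LinearBN wrappers: dropout is now governed by the separate prenet_dropout flag instead of being implied by prenet_type.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PrenetSketch(nn.Module):
    # Simplified stand-in: prenet_type picks the layer flavour ("original" vs "bn"),
    # prenet_dropout toggles dropout independently of it.
    def __init__(self, in_features, prenet_type="original", prenet_dropout=True,
                 out_features=[256, 256]):
        super().__init__()
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        sizes = [in_features] + out_features[:-1]
        # the real code uses LinearBN for "bn" and Linear for "original";
        # plain nn.Linear keeps this sketch self-contained
        self.layers = nn.ModuleList(
            [nn.Linear(s, o) for s, o in zip(sizes, out_features)])

    def forward(self, x):
        for linear in self.layers:
            if self.prenet_dropout:
                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
            else:
                x = F.relu(linear(x))
        return x

prenet = PrenetSketch(80, prenet_type="bn", prenet_dropout=True)  # e.g. 80 mel channels, r=1
out = prenet(torch.randn(4, 80))  # -> (4, 256)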
@@ -120,7 +121,7 @@ class LocationLayer(nn.Module):


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, location_attention,
                 attention_location_n_filters, attention_location_kernel_size,
                 windowing, norm, forward_attn, trans_agent):
        super(Attention, self).__init__()

@@ -130,38 +131,65 @@ class Attention(nn.Module):
            embedding_dim, attention_dim, bias=False, init_gain='tanh')
        self.v = Linear(attention_dim, 1, bias=True)
        if trans_agent:
            self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
            self.ta = nn.Linear(attention_rnn_dim + embedding_dim, 1, bias=True)
        if location_attention:
            self.location_layer = LocationLayer(attention_location_n_filters,
                                                attention_location_kernel_size,
                                                attention_dim)
        self._mask_value = -float("inf")
        self.windowing = windowing
        self.win_idx = None
        self.norm = norm
        self.forward_attn = forward_attn
        self.trans_agent = trans_agent
        self.location_attention = location_attention

    def init_win_idx(self):
        self.win_idx = -1
        self.win_back = 2
        self.win_front = 6

    def init_forward_attn_state(self, inputs):
        """
        Init forward attention states
        """
    def init_forward_attn(self, inputs):
        B = inputs.shape[0]
        T = inputs.shape[1]
        self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7 ], dim=1).to(inputs.device)
        self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)

    def get_attention(self, query, processed_inputs, attention_cat):
    def init_location_attention(self, inputs):
        B = inputs.shape[0]
        T = inputs.shape[1]
        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())

    def init_states(self, inputs):
        B = inputs.shape[0]
        T = inputs.shape[1]
        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
        if self.location_attention:
            self.init_location_attention(inputs)
        if self.forward_attn:
            self.init_forward_attn(inputs)
        if self.windowing:
            self.init_win_idx()

    def update_location_attention(self, alignments):
        self.attention_weights_cum += alignments

    def get_location_attention(self, query, processed_inputs):
        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
                                   self.attention_weights_cum.unsqueeze(1)),
                                  dim=1)
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights +
                       processed_inputs))
                       processed_inputs))
        energies = energies.squeeze(-1)
        return energies, processed_query

    def get_attention(self, query, processed_inputs):
        processed_query = self.query_layer(query.unsqueeze(1))
        energies = self.v(
            torch.tanh(processed_query +processed_inputs))
        energies = energies.squeeze(-1)
        return energies, processed_query
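As a rough functional summary of the split above, a hedged sketch of the two energy computations (the tensor shapes and the plain nn.Linear projection are assumptions, not the repo's Linear/LocationLayer classes): with location attention the energies also see the processed current-plus-cumulative alignment weights; without it they reduce to plain content-based energies.

import torch

def energies_sketch(processed_query, processed_inputs,
                    processed_attention_weights=None, v=None):
    # processed_query: (B, 1, D), processed_inputs: (B, T, D),
    # processed_attention_weights: (B, T, D) from the location conv, or None;
    # v: the D -> 1 projection.
    if processed_attention_weights is not None:  # location-sensitive branch
        e = v(torch.tanh(processed_query + processed_attention_weights + processed_inputs))
    else:                                        # plain content-based branch
        e = v(torch.tanh(processed_query + processed_inputs))
    return e.squeeze(-1)                         # (B, T)

B, T, D = 2, 7, 128
v = torch.nn.Linear(D, 1)
pq, pi = torch.randn(B, 1, D), torch.randn(B, T, D)
print(energies_sketch(pq, pi, v=v).shape)        # torch.Size([2, 7])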
@@ -180,7 +208,7 @@ class Attention(nn.Module):
            self.win_idx = torch.argmax(attention, 1).long()[0].item()
        return attention

    def apply_forward_attention(self, inputs, alignment, processed_query):
    def apply_forward_attention(self, inputs, alignment, query):
        # forward attention
        prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
        alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-8) * alignment

@@ -190,15 +218,18 @@ class Attention(nn.Module):
        context = context.squeeze(1)
        # compute transition agent
        if self.trans_agent:
            ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
            self.u = torch.sigmoid(self.ta(ta_input))
        return context, self.alpha, alignment
        return context, self.alpha

    def forward(self, attention_hidden_state, inputs, processed_inputs,
                attention_cat, mask):
        attention, processed_query = self.get_attention(
            attention_hidden_state, processed_inputs, attention_cat)

                mask):
        if self.location_attention:
            attention, processed_query = self.get_location_attention(
                attention_hidden_state, processed_inputs)
        else:
            attention, processed_query = self.get_attention(
                attention_hidden_state, processed_inputs)
        # apply masking
        if mask is not None:
            attention.data.masked_fill_(1 - mask, self._mask_value)
@@ -213,13 +244,16 @@ class Attention(nn.Module):
                             attention).sum(dim=1).unsqueeze(1)
        else:
            raise RuntimeError("Unknown value for attention norm type")
        if self.location_attention:
            self.update_location_attention(alignment)
        # apply forward attention if enabled
        if self.forward_attn:
            return self.apply_forward_attention(inputs, alignment, processed_query)
            context, self.attention_weights = self.apply_forward_attention(inputs, alignment, attention_hidden_state)
        else:
            context = torch.bmm(alignment.unsqueeze(1), inputs)
            context = context.squeeze(1)
        return context, alignment, alignment
            self.attention_weights = alignment
        return context


class Postnet(nn.Module):
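Putting the forward-attention pieces together, a hedged sketch of one step; the re-normalization of alpha and the bmm context computation fall between the hunks shown here and are filled in as assumptions.

import torch
import torch.nn.functional as F

def forward_attention_step(alpha, u, alignment, inputs, ta, query):
    # alpha (B, T) and u (B, 1): previous forward-attention state;
    # alignment (B, T): normalized energies for this step;
    # inputs (B, T, D): encoder outputs; ta: transition-agent Linear; query (B, D_q).
    prev_alpha = F.pad(alpha[:, :-1].clone(), (1, 0, 0, 0))          # shift right by one step
    alpha = (((1 - u) * alpha + u * prev_alpha) + 1e-8) * alignment  # recursion from the diff
    alpha = alpha / alpha.sum(dim=1, keepdim=True)                   # assumed re-normalization
    context = torch.bmm(alpha.unsqueeze(1), inputs).squeeze(1)       # (B, D)
    u = torch.sigmoid(ta(torch.cat([context, query], dim=-1)))       # transition agent update
    return context, alpha, u

B, T, D = 2, 5, 16
ta = torch.nn.Linear(D + D, 1)   # attention_rnn_dim + embedding_dim, both 16 here
ctx, alpha, u = forward_attention_step(
    torch.softmax(torch.randn(B, T), dim=1), torch.full((B, 1), 0.5),
    torch.softmax(torch.randn(B, T), dim=1), torch.randn(B, T, D),
    ta, torch.randn(B, D))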
@@ -289,7 +323,7 @@ class Encoder(nn.Module):

# adapted from https://github.com/NVIDIA/tacotron2/
class Decoder(nn.Module):
    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent):
    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, location_attn):
        super(Decoder, self).__init__()
        self.mel_channels = inputs_dim
        self.r = r

@@ -302,14 +336,14 @@ class Decoder(nn.Module):
        self.p_attention_dropout = 0.1
        self.p_decoder_dropout = 0.1

        self.prenet = Prenet(self.mel_channels * r, prenet_type,
        self.prenet = Prenet(self.mel_channels * r, prenet_type, prenet_dropout,
                             [self.prenet_dim, self.prenet_dim])

        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                         self.attention_rnn_dim)

        self.attention_layer = Attention(self.attention_rnn_dim, in_features,
                                         128, 32, 31, attn_win, attn_norm, forward_attn, trans_agent)
        self.attention_layer = Attention(self.attention_rnn_dim, in_features, 128, location_attn,
                                         32, 31, attn_win, attn_norm, forward_attn, trans_agent)

        self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                       self.decoder_rnn_dim, 1)
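Since the decoder builds the attention layer with positional arguments, the new argument order is easy to misread; a keyword-argument spelling of the same call, with illustrative sizes and an assumed layers.tacotron2 import path, would look roughly like:

# Hypothetical direct construction (values illustrative; import path assumed):
from layers.tacotron2 import Attention

attn = Attention(attention_rnn_dim=1024,
                 embedding_dim=512,
                 attention_dim=128,
                 location_attention=False,   # new flag: skip the location conv entirely
                 attention_location_n_filters=32,
                 attention_location_kernel_size=31,
                 windowing=False,
                 norm="softmax",
                 forward_attn=True,
                 trans_agent=False)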
@@ -351,9 +385,6 @@ class Decoder(nn.Module):

        self.context = Variable(
            inputs.data.new(B, self.encoder_embedding_dim).zero_())

        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())

        self.inputs = inputs
        self.processed_inputs = self.attention_layer.inputs_layer(inputs)

@@ -384,14 +415,10 @@ class Decoder(nn.Module):
        self.attention_cell = F.dropout(
            self.attention_cell, self.p_attention_dropout, self.training)

        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
                                   self.attention_weights_cum.unsqueeze(1)),
                                  dim=1)
        self.context, self.attention_weights, alignments = self.attention_layer(
        self.context = self.attention_layer(
            self.attention_hidden, self.inputs, self.processed_inputs,
            attention_cat, self.mask)
            self.mask)

        self.attention_weights_cum += alignments
        memory = torch.cat(
            (self.attention_hidden, self.context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
@@ -410,7 +437,7 @@ class Decoder(nn.Module):
        stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)

        gate_prediction = self.stopnet(stopnet_input)
        return decoder_output, gate_prediction, self.attention_weights
        return decoder_output, gate_prediction, self.attention_layer.attention_weights

    def forward(self, inputs, memories, mask):
        memory = self.get_go_frame(inputs).unsqueeze(0)

@@ -419,8 +446,7 @@ class Decoder(nn.Module):
        memories = self.prenet(memories)

        self._init_states(inputs, mask=mask)
        if self.attention_layer.forward_attn:
            self.attention_layer.init_forward_attn_state(inputs)
        self.attention_layer.init_states(inputs)

        outputs, stop_tokens, alignments = [], [], []
        while len(outputs) < memories.size(0) - 1:

@@ -441,8 +467,7 @@ class Decoder(nn.Module):
        self._init_states(inputs, mask=None)

        self.attention_layer.init_win_idx()
        if self.attention_layer.forward_attn:
            self.attention_layer.init_forward_attn_state(inputs)
        self.attention_layer.init_states(inputs)

        outputs, stop_tokens, alignments, t = [], [], [], 0
        stop_flags = [False, False, False]
@@ -460,7 +485,7 @@ class Decoder(nn.Module):
            stop_flags[2] = t > inputs.shape[1] * 2
            if all(stop_flags):
                stop_count += 1
                if stop_count > 2:
                if stop_count > 5:
                    break
            elif len(outputs) == self.max_decoder_steps:
                print(" | > Decoder stopped with 'max_decoder_steps")

@@ -485,8 +510,7 @@ class Decoder(nn.Module):
        self._init_states(inputs, mask=None, keep_states=True)

        self.attention_layer.init_win_idx()
        if self.attention_layer.forward_attn:
            self.attention_layer.init_forward_attn_state(inputs)
        self.attention_layer.init_states(inputs)
        outputs, stop_tokens, alignments, t = [], [], [], 0
        stop_flags = [False, False, False]
        stop_count = 0
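In isolation, the relaxed stopping rule touched above amounts to the following; the first two stop flags are set outside this hunk, so they are only passed in generically here, and the early return on max_decoder_steps is an assumption about what follows the print.

# Standalone illustration: decoding now tolerates five consecutive "all flags raised"
# steps (stop_count > 5) before breaking, instead of two.
def should_stop(stop_flags, stop_count, num_outputs, max_decoder_steps):
    if all(stop_flags):
        stop_count += 1
        if stop_count > 5:
            return True, stop_count
    elif num_outputs == max_decoder_steps:
        print(" | > Decoder stopped with 'max_decoder_steps")
        return True, stop_count
    return False, stop_count

stop, count = should_stop([True, True, True], stop_count=5, num_outputs=100,
                          max_decoder_steps=500)   # -> (True, 6)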
@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask

# TODO: match function arguments with tacotron
class Tacotron2(nn.Module):
    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False, trans_agent=False):
    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", prenet_dropout=True, forward_attn=False, trans_agent=False, location_attn=True):
        super(Tacotron2, self).__init__()
        self.n_mel_channels = 80
        self.n_frames_per_step = r

@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
        val = sqrt(3.0) * std # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(512)
        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent)
        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, location_attn)
        self.postnet = Postnet(self.n_mel_channels)

    def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
@@ -262,6 +262,8 @@ def setup_model(num_chars, c):
                    attn_win=c.windowing,
                    attn_norm=c.attention_norm,
                    prenet_type=c.prenet_type,
                    prenet_dropout=c.prenet_dropout,
                    forward_attn=c.use_forward_attn,
                    trans_agent=c.transition_agent)
                    trans_agent=c.transition_agent,
                    location_attn=c.location_attn)
    return model
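End to end, the two new config keys travel from config_cluster.json into the model roughly as below; the load_config helper and the models.tacotron2 import path are assumptions about the surrounding repo layout, and num_chars is a placeholder.

# Sketch only, not part of the commit.
from utils.generic_utils import load_config   # assumed helper
from models.tacotron2 import Tacotron2        # assumed module path

c = load_config("config_cluster.json")
model = Tacotron2(num_chars=130,                    # placeholder; comes from the symbol set
                  r=c.r,
                  attn_win=c.windowing,
                  attn_norm=c.attention_norm,
                  prenet_type=c.prenet_type,
                  prenet_dropout=c.prenet_dropout,  # new key in this commit
                  forward_attn=c.use_forward_attn,
                  trans_agent=c.transition_agent,
                  location_attn=c.location_attn)    # new key in this commit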
@@ -30,14 +30,14 @@ def plot_spectrogram(linear_output, audio):
    return fig


def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CONFIG, spectrogram=None):
def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CONFIG, spectrogram=None, output_path=None):
    if spectrogram is not None:
        num_plot = 4
    else:
        num_plot = 3

    label_fontsize = 16
    plt.figure(figsize=(8, 24))
    fig = plt.figure(figsize=(8, 24))

    plt.subplot(num_plot, 1, 1)
    plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None)

@@ -46,6 +46,7 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
    if CONFIG.use_phonemes:
        seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars)
        text = sequence_to_phoneme(seq)
        print(text)
    plt.yticks(range(len(text)), list(text))
    plt.colorbar()

@@ -69,3 +70,8 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
    plt.ylabel("Hz", fontsize=label_fontsize)
    plt.tight_layout()
    plt.colorbar()

    if output_path:
        print(output_path)
        fig.savefig(output_path)
        plt.close()
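The essence of the visual.py change, keeping the figure handle and saving it when output_path is given, can be reproduced in a few lines (a standalone matplotlib sketch, not the repo's visualize itself):

import matplotlib.pyplot as plt
import numpy as np

def plot_and_maybe_save(alignment, output_path=None):
    # Keep the figure handle and only write it to disk when output_path is given,
    # mirroring the new fig = plt.figure(...) / fig.savefig(output_path) pattern.
    fig = plt.figure(figsize=(8, 6))
    plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None)
    plt.colorbar()
    plt.tight_layout()
    if output_path:
        print(output_path)
        fig.savefig(output_path)
        plt.close()

plot_and_maybe_save(np.random.rand(120, 42), output_path="example_alignment.png")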