mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'tf-convert2' into dev
commit
df8fd3823d
30
config.json
30
config.json
|
@ -1,5 +1,5 @@
|
|||
{
|
||||
"model": "Tacotron2",
|
||||
"model": "Tacotron2",
|
||||
"run_name": "ljspeech",
|
||||
"run_description": "tacotron2",
|
||||
|
||||
|
@ -11,12 +11,12 @@
|
|||
"hop_length": 256, // stft window hop-lengh in ms.
|
||||
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
|
||||
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
|
||||
|
||||
|
||||
// Audio processing parameters
|
||||
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
||||
"preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
|
||||
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
|
||||
|
||||
|
||||
// Silence trimming
|
||||
"do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
||||
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
|
||||
|
@ -26,7 +26,7 @@
|
|||
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
||||
|
||||
// MelSpectrogram parameters
|
||||
"num_mels": 80, // size of the mel spec frame.
|
||||
"num_mels": 80, // size of the mel spec frame.
|
||||
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
|
||||
|
||||
|
@ -50,7 +50,7 @@
|
|||
// "punctuations":"!'(),-.:;? ",
|
||||
// "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
|
||||
// },
|
||||
|
||||
|
||||
// DISTRIBUTED TRAINING
|
||||
"distributed":{
|
||||
"backend": "nccl",
|
||||
|
@ -61,8 +61,8 @@
|
|||
|
||||
// TRAINING
|
||||
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"eval_batch_size":16,
|
||||
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
|
||||
"eval_batch_size":16,
|
||||
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
|
||||
"gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed.
|
||||
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
||||
"ga_alpha": 10.0, // weight for guided attention loss. If > 0, guided attention is enabled.
|
||||
|
@ -80,11 +80,11 @@
|
|||
"wd": 0.000001, // Weight decay weight.
|
||||
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths.
|
||||
|
||||
|
||||
// TACOTRON PRENET
|
||||
"memory_size": -1, // ONLY TACOTRON - size of the memory queue used fro storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame.
|
||||
"memory_size": -1, // ONLY TACOTRON - size of the memory queue used fro storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame.
|
||||
"prenet_type": "original", // "original" or "bn".
|
||||
"prenet_dropout": true, // enable/disable dropout at prenet.
|
||||
"prenet_dropout": true, // enable/disable dropout at prenet.
|
||||
|
||||
// ATTENTION
|
||||
"attention_type": "original", // 'original' or 'graves'
|
||||
|
@ -98,16 +98,16 @@
|
|||
"bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
|
||||
|
||||
// STOPNET
|
||||
"stopnet": true, // Train stopnet predicting the end of synthesis.
|
||||
"stopnet": true, // Train stopnet predicting the end of synthesis.
|
||||
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
|
||||
|
||||
// TENSORBOARD and LOGGING
|
||||
"print_step": 25, // Number of steps to log traning on console.
|
||||
"print_eval": false, // If True, it prints loss values in evalulation.
|
||||
"print_eval": false, // If True, it prints loss values in evalulation.
|
||||
"save_step": 10000, // Number of training steps expected to save traninpg stats and checkpoints.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
|
||||
// DATA LOADING
|
||||
"text_cleaner": "phoneme_cleaners",
|
||||
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
|
||||
|
@ -119,7 +119,7 @@
|
|||
|
||||
// PATHS
|
||||
"output_path": "/home/erogol/Models/LJSpeech/",
|
||||
|
||||
|
||||
// PHONEMES
|
||||
"phoneme_cache_path": "mozilla_us_phonemes_3", // phoneme computation is slow, therefore, it caches results in the given folder.
|
||||
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
|
||||
|
|
|
@ -33,7 +33,7 @@ class LinearBN(nn.Module):
|
|||
super(LinearBN, self).__init__()
|
||||
self.linear_layer = torch.nn.Linear(
|
||||
in_features, out_features, bias=bias)
|
||||
self.bn = nn.BatchNorm1d(out_features)
|
||||
self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5)
|
||||
self._init_w(init_gain)
|
||||
|
||||
def _init_w(self, init_gain):
|
||||
|
@ -45,7 +45,7 @@ class LinearBN(nn.Module):
|
|||
out = self.linear_layer(x)
|
||||
if len(out.shape) == 3:
|
||||
out = out.permute(1, 2, 0)
|
||||
out = self.bn(out)
|
||||
out = self.batch_normalization(out)
|
||||
if len(out.shape) == 3:
|
||||
out = out.permute(2, 0, 1)
|
||||
return out
|
||||
|
@ -63,18 +63,18 @@ class Prenet(nn.Module):
|
|||
self.prenet_dropout = prenet_dropout
|
||||
in_features = [in_features] + out_features[:-1]
|
||||
if prenet_type == "bn":
|
||||
self.layers = nn.ModuleList([
|
||||
self.linear_layers = nn.ModuleList([
|
||||
LinearBN(in_size, out_size, bias=bias)
|
||||
for (in_size, out_size) in zip(in_features, out_features)
|
||||
])
|
||||
elif prenet_type == "original":
|
||||
self.layers = nn.ModuleList([
|
||||
self.linear_layers = nn.ModuleList([
|
||||
Linear(in_size, out_size, bias=bias)
|
||||
for (in_size, out_size) in zip(in_features, out_features)
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
for linear in self.layers:
|
||||
for linear in self.linear_layers:
|
||||
if self.prenet_dropout:
|
||||
x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
|
||||
else:
|
||||
|
@ -93,7 +93,7 @@ class LocationLayer(nn.Module):
|
|||
attention_n_filters=32,
|
||||
attention_kernel_size=31):
|
||||
super(LocationLayer, self).__init__()
|
||||
self.location_conv = nn.Conv1d(
|
||||
self.location_conv1d = nn.Conv1d(
|
||||
in_channels=2,
|
||||
out_channels=attention_n_filters,
|
||||
kernel_size=attention_kernel_size,
|
||||
|
@ -104,7 +104,7 @@ class LocationLayer(nn.Module):
|
|||
attention_n_filters, attention_dim, bias=False, init_gain='tanh')
|
||||
|
||||
def forward(self, attention_cat):
|
||||
processed_attention = self.location_conv(attention_cat)
|
||||
processed_attention = self.location_conv1d(attention_cat)
|
||||
processed_attention = self.location_dense(
|
||||
processed_attention.transpose(1, 2))
|
||||
return processed_attention
|
||||
|
|
|
@ -6,130 +6,128 @@ from .common_layers import init_attn, Prenet, Linear
|
|||
|
||||
|
||||
class ConvBNBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, nonlinear=None):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, activation=None):
|
||||
super(ConvBNBlock, self).__init__()
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
padding = (kernel_size - 1) // 2
|
||||
conv1d = nn.Conv1d(in_channels,
|
||||
self.convolution1d = nn.Conv1d(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
padding=padding)
|
||||
norm = nn.BatchNorm1d(out_channels)
|
||||
dropout = nn.Dropout(p=0.5)
|
||||
if nonlinear == 'relu':
|
||||
self.net = nn.Sequential(conv1d, norm, nn.ReLU(), dropout)
|
||||
elif nonlinear == 'tanh':
|
||||
self.net = nn.Sequential(conv1d, norm, nn.Tanh(), dropout)
|
||||
self.batch_normalization = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5)
|
||||
self.dropout = nn.Dropout(p=0.5)
|
||||
if activation == 'relu':
|
||||
self.activation = nn.ReLU()
|
||||
elif activation == 'tanh':
|
||||
self.activation = nn.Tanh()
|
||||
else:
|
||||
self.net = nn.Sequential(conv1d, norm, dropout)
|
||||
self.activation = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
output = self.net(x)
|
||||
return output
|
||||
o = self.convolution1d(x)
|
||||
o = self.batch_normalization(o)
|
||||
o = self.activation(o)
|
||||
o = self.dropout(o)
|
||||
return o
|
||||
|
||||
|
||||
class Postnet(nn.Module):
|
||||
def __init__(self, mel_dim, num_convs=5):
|
||||
def __init__(self, output_dim, num_convs=5):
|
||||
super(Postnet, self).__init__()
|
||||
self.convolutions = nn.ModuleList()
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(mel_dim, 512, kernel_size=5, nonlinear='tanh'))
|
||||
ConvBNBlock(output_dim, 512, kernel_size=5, activation='tanh'))
|
||||
for _ in range(1, num_convs - 1):
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(512, 512, kernel_size=5, nonlinear='tanh'))
|
||||
ConvBNBlock(512, 512, kernel_size=5, activation='tanh'))
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(512, mel_dim, kernel_size=5, nonlinear=None))
|
||||
ConvBNBlock(512, output_dim, kernel_size=5, activation=None))
|
||||
|
||||
def forward(self, x):
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
x = layer(x)
|
||||
return x
|
||||
o = layer(o)
|
||||
return o
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_features=512):
|
||||
def __init__(self, output_input_dim=512):
|
||||
super(Encoder, self).__init__()
|
||||
convolutions = []
|
||||
self.convolutions = nn.ModuleList()
|
||||
for _ in range(3):
|
||||
convolutions.append(
|
||||
ConvBNBlock(in_features, in_features, 5, 'relu'))
|
||||
self.convolutions = nn.Sequential(*convolutions)
|
||||
self.lstm = nn.LSTM(in_features,
|
||||
int(in_features / 2),
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(output_input_dim, output_input_dim, 5, 'relu'))
|
||||
self.lstm = nn.LSTM(output_input_dim,
|
||||
int(output_input_dim / 2),
|
||||
num_layers=1,
|
||||
batch_first=True,
|
||||
bias=True,
|
||||
bidirectional=True)
|
||||
self.rnn_state = None
|
||||
|
||||
def forward(self, x, input_lengths):
|
||||
x = self.convolutions(x)
|
||||
x = x.transpose(1, 2)
|
||||
x = nn.utils.rnn.pack_padded_sequence(x,
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
o = layer(o)
|
||||
o = o.transpose(1, 2)
|
||||
o = nn.utils.rnn.pack_padded_sequence(o,
|
||||
input_lengths,
|
||||
batch_first=True)
|
||||
self.lstm.flatten_parameters()
|
||||
outputs, _ = self.lstm(x)
|
||||
outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
outputs,
|
||||
batch_first=True,
|
||||
)
|
||||
return outputs
|
||||
o, _ = self.lstm(o)
|
||||
o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True)
|
||||
return o
|
||||
|
||||
def inference(self, x):
|
||||
x = self.convolutions(x)
|
||||
x = x.transpose(1, 2)
|
||||
self.lstm.flatten_parameters()
|
||||
outputs, _ = self.lstm(x)
|
||||
return outputs
|
||||
|
||||
def inference_truncated(self, x):
|
||||
"""
|
||||
Preserve encoder state for continuous inference
|
||||
"""
|
||||
x = self.convolutions(x)
|
||||
x = x.transpose(1, 2)
|
||||
self.lstm.flatten_parameters()
|
||||
outputs, self.rnn_state = self.lstm(x, self.rnn_state)
|
||||
return outputs
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
o = layer(o)
|
||||
o = o.transpose(1, 2)
|
||||
# self.lstm.flatten_parameters()
|
||||
o, _ = self.lstm(o)
|
||||
return o
|
||||
|
||||
|
||||
# adapted from https://github.com/NVIDIA/tacotron2/
|
||||
class Decoder(nn.Module):
|
||||
# Pylint gets confused by PyTorch conventions here
|
||||
#pylint: disable=attribute-defined-outside-init
|
||||
def __init__(self, in_features, memory_dim, r, attn_type, attn_win, attn_norm,
|
||||
def __init__(self, input_dim, frame_dim, r, attn_type, attn_win, attn_norm,
|
||||
prenet_type, prenet_dropout, forward_attn, trans_agent,
|
||||
forward_attn_mask, location_attn, attn_K, separate_stopnet,
|
||||
speaker_embedding_dim):
|
||||
super(Decoder, self).__init__()
|
||||
self.memory_dim = memory_dim
|
||||
self.frame_dim = frame_dim
|
||||
self.r_init = r
|
||||
self.r = r
|
||||
self.encoder_embedding_dim = in_features
|
||||
self.encoder_embedding_dim = input_dim
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.max_decoder_steps = 1000
|
||||
self.gate_threshold = 0.5
|
||||
|
||||
# model dimensions
|
||||
self.query_dim = 1024
|
||||
self.decoder_rnn_dim = 1024
|
||||
self.prenet_dim = 256
|
||||
self.max_decoder_steps = 1000
|
||||
self.gate_threshold = 0.5
|
||||
self.attn_dim = 128
|
||||
self.p_attention_dropout = 0.1
|
||||
self.p_decoder_dropout = 0.1
|
||||
|
||||
# memory -> |Prenet| -> processed_memory
|
||||
prenet_dim = self.memory_dim
|
||||
self.prenet = Prenet(
|
||||
prenet_dim,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
out_features=[self.prenet_dim, self.prenet_dim],
|
||||
bias=False)
|
||||
prenet_dim = self.frame_dim
|
||||
self.prenet = Prenet(prenet_dim,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
out_features=[self.prenet_dim, self.prenet_dim],
|
||||
bias=False)
|
||||
|
||||
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
|
||||
self.query_dim)
|
||||
self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim,
|
||||
self.query_dim,
|
||||
bias=True)
|
||||
|
||||
self.attention = init_attn(attn_type=attn_type,
|
||||
query_dim=self.query_dim,
|
||||
embedding_dim=in_features,
|
||||
embedding_dim=input_dim,
|
||||
attention_dim=128,
|
||||
location_attention=location_attn,
|
||||
attention_location_n_filters=32,
|
||||
|
@ -141,15 +139,16 @@ class Decoder(nn.Module):
|
|||
forward_attn_mask=forward_attn_mask,
|
||||
attn_K=attn_K)
|
||||
|
||||
self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
|
||||
self.decoder_rnn_dim, 1)
|
||||
self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim,
|
||||
self.decoder_rnn_dim,
|
||||
bias=True)
|
||||
|
||||
self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
|
||||
self.memory_dim * self.r_init)
|
||||
self.linear_projection = Linear(self.decoder_rnn_dim + input_dim,
|
||||
self.frame_dim * self.r_init)
|
||||
|
||||
self.stopnet = nn.Sequential(
|
||||
nn.Dropout(0.1),
|
||||
Linear(self.decoder_rnn_dim + self.memory_dim * self.r_init,
|
||||
Linear(self.decoder_rnn_dim + self.frame_dim * self.r_init,
|
||||
1,
|
||||
bias=True,
|
||||
init_gain='sigmoid'))
|
||||
|
@ -161,7 +160,7 @@ class Decoder(nn.Module):
|
|||
def get_go_frame(self, inputs):
|
||||
B = inputs.size(0)
|
||||
memory = torch.zeros(1, device=inputs.device).repeat(B,
|
||||
self.memory_dim * self.r)
|
||||
self.frame_dim * self.r)
|
||||
return memory
|
||||
|
||||
def _init_states(self, inputs, mask, keep_states=False):
|
||||
|
@ -187,9 +186,9 @@ class Decoder(nn.Module):
|
|||
Reshape the spectrograms for given 'r'
|
||||
"""
|
||||
# Grouping multiple frames if necessary
|
||||
if memory.size(-1) == self.memory_dim:
|
||||
if memory.size(-1) == self.frame_dim:
|
||||
memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
|
||||
# Time first (T_decoder, B, memory_dim)
|
||||
# Time first (T_decoder, B, frame_dim)
|
||||
memory = memory.transpose(0, 1)
|
||||
return memory
|
||||
|
||||
|
@ -197,22 +196,22 @@ class Decoder(nn.Module):
|
|||
alignments = torch.stack(alignments).transpose(0, 1)
|
||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||
outputs = outputs.view(outputs.size(0), -1, self.memory_dim)
|
||||
outputs = outputs.view(outputs.size(0), -1, self.frame_dim)
|
||||
outputs = outputs.transpose(1, 2)
|
||||
return outputs, stop_tokens, alignments
|
||||
|
||||
def _update_memory(self, memory):
|
||||
if len(memory.shape) == 2:
|
||||
return memory[:, self.memory_dim * (self.r - 1):]
|
||||
return memory[:, :, self.memory_dim * (self.r - 1):]
|
||||
return memory[:, self.frame_dim * (self.r - 1):]
|
||||
return memory[:, :, self.frame_dim * (self.r - 1):]
|
||||
|
||||
def decode(self, memory):
|
||||
'''
|
||||
shapes:
|
||||
- memory: B x r * self.memory_dim
|
||||
- memory: B x r * self.frame_dim
|
||||
'''
|
||||
# self.context: B x D_en
|
||||
# query_input: B x D_en + (r * self.memory_dim)
|
||||
# query_input: B x D_en + (r * self.frame_dim)
|
||||
query_input = torch.cat((memory, self.context), -1)
|
||||
# self.query and self.attention_rnn_cell_state : B x D_attn_rnn
|
||||
self.query, self.attention_rnn_cell_state = self.attention_rnn(
|
||||
|
@ -235,16 +234,16 @@ class Decoder(nn.Module):
|
|||
# B x (D_decoder_rnn + D_en)
|
||||
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
|
||||
dim=1)
|
||||
# B x (self.r * self.memory_dim)
|
||||
# B x (self.r * self.frame_dim)
|
||||
decoder_output = self.linear_projection(decoder_hidden_context)
|
||||
# B x (D_decoder_rnn + (self.r * self.memory_dim))
|
||||
# B x (D_decoder_rnn + (self.r * self.frame_dim))
|
||||
stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
|
||||
if self.separate_stopnet:
|
||||
stop_token = self.stopnet(stopnet_input.detach())
|
||||
else:
|
||||
stop_token = self.stopnet(stopnet_input)
|
||||
# select outputs for the reduction rate self.r
|
||||
decoder_output = decoder_output[:, :self.r * self.memory_dim]
|
||||
decoder_output = decoder_output[:, :self.r * self.frame_dim]
|
||||
return decoder_output, self.attention.attention_weights, stop_token
|
||||
|
||||
def forward(self, inputs, memories, mask, speaker_embeddings=None):
|
||||
|
|
|
@ -29,7 +29,7 @@ class Tacotron2(nn.Module):
|
|||
super(Tacotron2, self).__init__()
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.n_frames_per_step = r
|
||||
self.r = r
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
decoder_dim = 512 if num_speakers > 1 else 512
|
||||
encoder_dim = 512 if num_speakers > 1 else 512
|
||||
|
|
|
@ -6,7 +6,8 @@ import torch as T
|
|||
from TTS.server.synthesizer import Synthesizer
|
||||
from TTS.tests import get_tests_input_path, get_tests_output_path
|
||||
from TTS.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model
|
||||
from TTS.utils.generic_utils import setup_model
|
||||
from TTS.utils.io import load_config, save_checkpoint
|
||||
|
||||
|
||||
class DemoServerTest(unittest.TestCase):
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import numpy as np
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from TTS.utils.generic_utils import load_config
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.datasets import TTSDataset
|
||||
from TTS.datasets.preprocess import ljspeech
|
||||
|
|
|
@ -6,7 +6,7 @@ import numpy as np
|
|||
|
||||
from torch import optim
|
||||
from torch import nn
|
||||
from TTS.utils.generic_utils import load_config
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.layers.losses import MSELossMasked
|
||||
from TTS.models.tacotron2 import Tacotron2
|
||||
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
import os
|
||||
import copy
|
||||
import torch
|
||||
import unittest
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from torch import optim
|
||||
from torch import nn
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.layers.losses import MSELossMasked
|
||||
from TTS.tf.models.tacotron2 import Tacotron2
|
||||
|
||||
#pylint: disable=unused-variable
|
||||
|
||||
torch.manual_seed(1)
|
||||
use_cuda = torch.cuda.is_available()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
file_path = os.path.dirname(os.path.realpath(__file__))
|
||||
c = load_config(os.path.join(file_path, 'test_config.json'))
|
||||
|
||||
|
||||
class TacotronTFTrainTest(unittest.TestCase):
|
||||
def test_train_step(self):
|
||||
''' test forward pass '''
|
||||
input = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||
input_lengths = torch.randint(100, 128, (8, )).long().to(device)
|
||||
input_lengths = torch.sort(input_lengths, descending=True)[0]
|
||||
mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
|
||||
mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
|
||||
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
|
||||
stop_targets = torch.zeros(8, 30, 1).float().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
||||
|
||||
input = tf.convert_to_tensor(input.cpu().numpy())
|
||||
input_lengths = tf.convert_to_tensor(input_lengths.cpu().numpy())
|
||||
mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy())
|
||||
|
||||
for idx in mel_lengths:
|
||||
stop_targets[:, int(idx.item()):, 0] = 1.0
|
||||
|
||||
stop_targets = stop_targets.view(input.shape[0],
|
||||
stop_targets.size(1) // c.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
||||
|
||||
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5)
|
||||
# training pass
|
||||
output = model(input, input_lengths, mel_spec, training=True)
|
||||
|
||||
# check model output shapes
|
||||
assert np.all(output[0].shape == mel_spec.shape)
|
||||
assert np.all(output[1].shape == mel_spec.shape)
|
||||
assert output[2].shape[2] == input.shape[1]
|
||||
assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r)
|
||||
assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r)
|
||||
|
||||
# inference pass
|
||||
output = model(input, training=False)
|
|
@ -5,7 +5,7 @@ import unittest
|
|||
|
||||
from torch import optim
|
||||
from torch import nn
|
||||
from TTS.utils.generic_utils import load_config
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.layers.losses import L1LossMasked
|
||||
from TTS.models.tacotron import Tacotron
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import os
|
|||
import unittest
|
||||
from TTS.utils.text import *
|
||||
from TTS.tests import get_tests_path
|
||||
from TTS.utils.generic_utils import load_config
|
||||
from TTS.utils.io import load_config
|
||||
|
||||
TESTS_PATH = get_tests_path()
|
||||
conf = load_config(os.path.join(TESTS_PATH, 'test_config.json'))
|
||||
|
@ -92,4 +92,4 @@ def test_text2phone():
|
|||
gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"
|
||||
lang = "en-us"
|
||||
ph = text2phone(text, lang)
|
||||
assert gt == ph, f"\n{phonemes} \n vs \n{gt}"
|
||||
assert gt == ph, f"\n{phonemes} \n vs \n{gt}"
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
## Utilities to Convert Models to Tensorflow2
|
||||
Here there are utilities to convert trained Torch models to Tensorflow (2.2>=).
|
||||
|
||||
We currently support Tacotron2 with Location Sensitive Attention.
|
||||
|
||||
Be aware that our old Torch models may not work with this module due to additional changes in layer naming convention. Therefore, you need to train new models or handle these changes.
|
||||
|
||||
We do not plan to share training scripts for Tensorflow in near future. But any contribution in that direction would be more than welcome.
|
||||
|
||||
To see how you can use TF model at inference, check the notebook.
|
||||
|
||||
This is an experimental release. If you encounter an error, please put an issue or in the best send a PR but you are mostly on your own.
|
|
@ -0,0 +1,196 @@
|
|||
# %%
|
||||
import sys
|
||||
sys.path.append('/home/erogol/Projects')
|
||||
import os
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
||||
# %%
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
import tensorflow as tf
|
||||
from fuzzywuzzy import fuzz
|
||||
|
||||
from TTS.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.utils.generic_utils import setup_model, count_parameters
|
||||
from TTS.utils.io import load_config
|
||||
from TTS_tf.models.tacotron2 import Tacotron2
|
||||
from TTS_tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name
|
||||
from TTS_tf.utils.generic_utils import save_checkpoint
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--torch_model_path',
|
||||
type=str,
|
||||
help='Path to target torch model to be converted to TF.')
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
help='Path to config file of torch model.')
|
||||
parser.add_argument(
|
||||
'--output_path',
|
||||
type=str,
|
||||
help='path to save TF model weights.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# load model config
|
||||
config_path = args.config_path
|
||||
c = load_config(config_path)
|
||||
num_speakers = 0
|
||||
|
||||
# init torch model
|
||||
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
||||
model = setup_model(num_chars, num_speakers, c)
|
||||
checkpoint = torch.load(args.torch_model_path, map_location=torch.device('cpu'))
|
||||
state_dict = checkpoint['model']
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# init tf model
|
||||
model_tf = Tacotron2(num_chars=num_chars,
|
||||
num_speakers=num_speakers,
|
||||
r=model.decoder.r,
|
||||
postnet_output_dim=c.audio['num_mels'],
|
||||
decoder_output_dim=c.audio['num_mels'],
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
prenet_type=c.prenet_type,
|
||||
prenet_dropout=c.prenet_dropout,
|
||||
forward_attn=c.use_forward_attn,
|
||||
trans_agent=c.transition_agent,
|
||||
forward_attn_mask=c.forward_attn_mask,
|
||||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder)
|
||||
|
||||
# set initial layer mapping - these are not captured by the below heuristic approach
|
||||
# TODO: set layer names so that we can remove these manual matching
|
||||
common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
|
||||
var_map = [
|
||||
('tacotron2/embedding/embeddings:0', 'embedding.weight'),
|
||||
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0', 'encoder.lstm.weight_ih_l0'),
|
||||
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0', 'encoder.lstm.weight_hh_l0'),
|
||||
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0', 'encoder.lstm.weight_ih_l0_reverse'),
|
||||
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0', 'encoder.lstm.weight_hh_l0_reverse'),
|
||||
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0', ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
|
||||
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0', ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
|
||||
('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
|
||||
('decoder/linear_projection/kernel:0', 'decoder.linear_projection.linear_layer.weight'),
|
||||
('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight')
|
||||
]
|
||||
|
||||
|
||||
# %%
|
||||
# get tf_model graph
|
||||
input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()
|
||||
mel_pred = model_tf(input_ids, training=False)
|
||||
|
||||
# get tf variables
|
||||
tf_vars = model_tf.weights
|
||||
|
||||
# match variable names with fuzzy logic
|
||||
torch_var_names = list(state_dict.keys())
|
||||
tf_var_names = [we.name for we in model_tf.weights]
|
||||
for tf_name in tf_var_names:
|
||||
# skip re-mapped layer names
|
||||
if tf_name in [name[0] for name in var_map]:
|
||||
continue
|
||||
tf_name_edited = convert_tf_name(tf_name)
|
||||
ratios = [fuzz.ratio(torch_name, tf_name_edited) for torch_name in torch_var_names]
|
||||
max_idx = np.argmax(ratios)
|
||||
matching_name = torch_var_names[max_idx]
|
||||
del torch_var_names[max_idx]
|
||||
var_map.append((tf_name, matching_name))
|
||||
|
||||
|
||||
# %%
|
||||
# print variable match
|
||||
from pprint import pprint
|
||||
pprint(var_map)
|
||||
pprint(torch_var_names)
|
||||
|
||||
# pass weights
|
||||
tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
|
||||
|
||||
# Compare TF and TORCH models
|
||||
# %%
|
||||
# check embedding outputs
|
||||
model.eval()
|
||||
input_ids = torch.randint(0, 24, (1, 128)).long()
|
||||
|
||||
o_t = model.embedding(input_ids)
|
||||
o_tf = model_tf.embedding(input_ids.detach().numpy())
|
||||
assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum()
|
||||
|
||||
# compare encoder outputs
|
||||
oo_en = model.encoder.inference(o_t.transpose(1,2))
|
||||
ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False)
|
||||
assert compare_torch_tf(oo_en, ooo_en) < 1e-5
|
||||
|
||||
# compare decoder.attention_rnn
|
||||
inp = torch.rand([1, 768])
|
||||
inp_tf = inp.numpy()
|
||||
model.decoder._init_states(oo_en, mask=None)
|
||||
output, cell_state = model.decoder.attention_rnn(inp)
|
||||
states = model_tf.decoder.build_decoder_initial_states(1,512,128)
|
||||
output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False)
|
||||
assert compare_torch_tf(output, output_tf).mean() < 1e-5
|
||||
|
||||
# compare decoder.attention
|
||||
query = output
|
||||
inputs = torch.rand([1, 128, 512])
|
||||
query_tf = query.detach().numpy()
|
||||
inputs_tf = inputs.numpy()
|
||||
|
||||
model.decoder.attention.init_states(inputs)
|
||||
processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
|
||||
loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs)
|
||||
context = model.decoder.attention(query, inputs, processes_inputs, None)
|
||||
|
||||
model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf))
|
||||
loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf)
|
||||
context_tf = model_tf.decoder.attention(query_tf, training=False)
|
||||
|
||||
assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5
|
||||
assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5
|
||||
assert compare_torch_tf(context, context_tf) < 1e-5
|
||||
|
||||
# compare decoder.decoder_rnn
|
||||
input = torch.rand([1, 1536])
|
||||
input_tf = input.numpy()
|
||||
model.decoder._init_states(oo_en, mask=None)
|
||||
output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
|
||||
states = model_tf.decoder.build_decoder_initial_states(1,512,128)
|
||||
output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False)
|
||||
assert abs(input - input_tf).mean() < 1e-5
|
||||
assert compare_torch_tf(output, output_tf).mean() < 1e-5
|
||||
|
||||
# compare decoder.linear_projection
|
||||
input = torch.rand([1, 1536])
|
||||
input_tf = input.numpy()
|
||||
output = model.decoder.linear_projection(input)
|
||||
output_tf = model_tf.decoder.linear_projection(input_tf, training=False)
|
||||
assert compare_torch_tf(output, output_tf) < 1e-5
|
||||
|
||||
# compare decoder outputs
|
||||
model.decoder.max_decoder_steps = 100
|
||||
model_tf.decoder.set_max_decoder_steps(100)
|
||||
output, align, stop = model.decoder.inference(oo_en)
|
||||
states = model_tf.decoder.build_decoder_initial_states(1,512,128)
|
||||
output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False)
|
||||
assert compare_torch_tf(output.transpose(1,2), output_tf) < 1e-4
|
||||
|
||||
# compare the whole model output
|
||||
outputs_torch = model.inference(input_ids)
|
||||
outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy()))
|
||||
print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean() )
|
||||
assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5
|
||||
assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4
|
||||
|
||||
# %%
|
||||
# save tf model
|
||||
save_checkpoint(model_tf, None, checkpoint['step'], checkpoint['epoch'],
|
||||
checkpoint['r'], args.output_path)
|
||||
print(' > Model conversion is successfully completed :).')
|
||||
|
|
@ -0,0 +1,258 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from tensorflow.python.ops import math_ops
|
||||
# from tensorflow_addons.seq2seq import BahdanauAttention
|
||||
|
||||
from TTS.tf.utils.tf_utils import shape_list
|
||||
|
||||
|
||||
class Linear(keras.layers.Layer):
|
||||
def __init__(self, units, use_bias, **kwargs):
|
||||
super(Linear, self).__init__(**kwargs)
|
||||
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
|
||||
self.activation = keras.layers.ReLU()
|
||||
|
||||
def call(self, x, training=None):
|
||||
"""
|
||||
shapes:
|
||||
x: B x T x C
|
||||
"""
|
||||
return self.activation(self.linear_layer(x))
|
||||
|
||||
|
||||
class LinearBN(keras.layers.Layer):
|
||||
def __init__(self, units, use_bias, **kwargs):
|
||||
super(LinearBN, self).__init__(**kwargs)
|
||||
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
|
||||
self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization')
|
||||
self.activation = keras.layers.ReLU()
|
||||
|
||||
def call(self, x, training=None):
|
||||
"""
|
||||
shapes:
|
||||
x: B x T x C
|
||||
"""
|
||||
out = self.linear_layer(x)
|
||||
out = self.batch_normalization(out, training=training)
|
||||
return self.activation(out)
|
||||
|
||||
|
||||
class Prenet(keras.layers.Layer):
|
||||
def __init__(self,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
units,
|
||||
bias,
|
||||
**kwargs):
|
||||
super(Prenet, self).__init__(**kwargs)
|
||||
self.prenet_type = prenet_type
|
||||
self.prenet_dropout = prenet_dropout
|
||||
self.linear_layers = []
|
||||
if prenet_type == "bn":
|
||||
self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
|
||||
elif prenet_type == "original":
|
||||
self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
|
||||
else:
|
||||
raise RuntimeError(' [!] Unknown prenet type.')
|
||||
if prenet_dropout:
|
||||
self.dropout = keras.layers.Dropout(rate=0.5)
|
||||
|
||||
def call(self, x, training=None):
|
||||
"""
|
||||
shapes:
|
||||
x: B x T x C
|
||||
"""
|
||||
for linear in self.linear_layers:
|
||||
if self.prenet_dropout:
|
||||
x = self.dropout(linear(x), training=training)
|
||||
else:
|
||||
x = linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def _sigmoid_norm(score):
|
||||
attn_weights = tf.nn.sigmoid(score)
|
||||
attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True)
|
||||
return attn_weights
|
||||
|
||||
|
||||
class Attention(keras.layers.Layer):
|
||||
"""TODO: implement forward_attention"""
|
||||
"""TODO: location sensitive attention"""
|
||||
"""TODO: implement attention windowing """
|
||||
def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters,
|
||||
loc_attn_kernel_size, use_windowing, norm, use_forward_attn,
|
||||
use_trans_agent, use_forward_attn_mask, **kwargs):
|
||||
super(Attention, self).__init__(**kwargs)
|
||||
self.use_loc_attn = use_loc_attn
|
||||
self.loc_attn_n_filters = loc_attn_n_filters
|
||||
self.loc_attn_kernel_size = loc_attn_kernel_size
|
||||
self.use_windowing = use_windowing
|
||||
self.norm = norm
|
||||
self.use_forward_attn = use_forward_attn
|
||||
self.use_trans_agent = use_trans_agent
|
||||
self.use_forward_attn_mask = use_forward_attn_mask
|
||||
self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer')
|
||||
self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer')
|
||||
self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer')
|
||||
if use_loc_attn:
|
||||
self.location_conv1d = keras.layers.Conv1D(
|
||||
filters=loc_attn_n_filters,
|
||||
kernel_size=loc_attn_kernel_size,
|
||||
padding='same',
|
||||
use_bias=False,
|
||||
name='location_layer/location_conv1d')
|
||||
self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense')
|
||||
if norm == 'softmax':
|
||||
self.norm_func = tf.nn.softmax
|
||||
elif norm == 'sigmoid':
|
||||
self.norm_func = _sigmoid_norm
|
||||
else:
|
||||
raise ValueError("Unknown value for attention norm type")
|
||||
|
||||
def init_states(self, batch_size, value_length):
|
||||
states = ()
|
||||
if self.use_loc_attn:
|
||||
attention_cum = tf.zeros([batch_size, value_length])
|
||||
attention_old = tf.zeros([batch_size, value_length])
|
||||
states = (attention_cum, attention_old)
|
||||
return states
|
||||
|
||||
def process_values(self, values):
|
||||
""" cache values for decoder iterations """
|
||||
self.processed_values = self.inputs_layer(values)
|
||||
self.values = values
|
||||
|
||||
def get_loc_attn(self, query, states):
|
||||
""" compute location attention, query layer and
|
||||
unnorm. attention weights"""
|
||||
attention_cum, attention_old = states
|
||||
attn_cat = tf.stack([attention_old, attention_cum],
|
||||
axis=2)
|
||||
|
||||
processed_query = self.query_layer(tf.expand_dims(query, 1))
|
||||
processed_attn = self.location_dense(self.location_conv1d(attn_cat))
|
||||
score = self.v(
|
||||
tf.nn.tanh(self.processed_values + processed_query +
|
||||
processed_attn))
|
||||
score = tf.squeeze(score, axis=2)
|
||||
return score, processed_query
|
||||
|
||||
def get_attn(self, query):
|
||||
""" compute query layer and unnormalized attention weights """
|
||||
processed_query = self.query_layer(tf.expand_dims(query, 1))
|
||||
score = self.v(tf.nn.tanh(self.processed_values + processed_query))
|
||||
score = tf.squeeze(score, axis=2)
|
||||
return score, processed_query
|
||||
|
||||
def apply_score_masking(self, score, mask):
|
||||
""" ignore sequence paddings """
|
||||
padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
|
||||
# Bias so padding positions do not contribute to attention distribution.
|
||||
score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
|
||||
return score
|
||||
|
||||
def call(self, query, states):
|
||||
"""
|
||||
shapes:
|
||||
query: B x D
|
||||
"""
|
||||
if self.use_loc_attn:
|
||||
score, processed_query = self.get_loc_attn(query, states)
|
||||
else:
|
||||
score, processed_query = self.get_attn(query)
|
||||
|
||||
# TODO: masking
|
||||
# if mask is not None:
|
||||
# self.apply_score_masking(score, mask)
|
||||
# attn_weights shape == (batch_size, max_length, 1)
|
||||
|
||||
attn_weights = self.norm_func(score)
|
||||
|
||||
# update attention states
|
||||
if self.use_loc_attn:
|
||||
states = (states[0] + attn_weights, attn_weights)
|
||||
else:
|
||||
states = ()
|
||||
|
||||
# context_vector shape after sum == (batch_size, hidden_size)
|
||||
context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
|
||||
context_vector = tf.squeeze(context_vector, axis=1)
|
||||
return context_vector, attn_weights, states
|
||||
|
||||
|
||||
# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b):
|
||||
# dtype = processed_query.dtype
|
||||
# num_units = keys.shape[-1].value or array_ops.shape(keys)[-1]
|
||||
# return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2])
|
||||
|
||||
|
||||
# class LocationSensitiveAttention(BahdanauAttention):
|
||||
# def __init__(self,
|
||||
# units,
|
||||
# memory=None,
|
||||
# memory_sequence_length=None,
|
||||
# normalize=False,
|
||||
# probability_fn="softmax",
|
||||
# kernel_initializer="glorot_uniform",
|
||||
# dtype=None,
|
||||
# name="LocationSensitiveAttention",
|
||||
# location_attention_filters=32,
|
||||
# location_attention_kernel_size=31):
|
||||
|
||||
# super(LocationSensitiveAttention,
|
||||
# self).__init__(units=units,
|
||||
# memory=memory,
|
||||
# memory_sequence_length=memory_sequence_length,
|
||||
# normalize=normalize,
|
||||
# probability_fn='softmax', ## parent module default
|
||||
# kernel_initializer=kernel_initializer,
|
||||
# dtype=dtype,
|
||||
# name=name)
|
||||
# if probability_fn == 'sigmoid':
|
||||
# self.probability_fn = lambda score, _: self._sigmoid_normalization(score)
|
||||
# self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False)
|
||||
# self.location_dense = keras.layers.Dense(units, use_bias=False)
|
||||
# # self.v = keras.layers.Dense(1, use_bias=True)
|
||||
|
||||
# def _location_sensitive_score(self, processed_query, keys, processed_loc):
|
||||
# processed_query = tf.expand_dims(processed_query, 1)
|
||||
# return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2])
|
||||
|
||||
# def _location_sensitive(self, alignment_cum, alignment_old):
|
||||
# alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2)
|
||||
# return self.location_dense(self.location_conv(alignment_cat))
|
||||
|
||||
# def _sigmoid_normalization(self, score):
|
||||
# return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True)
|
||||
|
||||
# # def _apply_masking(self, score, mask):
|
||||
# # padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
|
||||
# # # Bias so padding positions do not contribute to attention distribution.
|
||||
# # score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
|
||||
# # return score
|
||||
|
||||
# def _calculate_attention(self, query, state):
|
||||
# alignment_cum, alignment_old = state[:2]
|
||||
# processed_query = self.query_layer(
|
||||
# query) if self.query_layer else query
|
||||
# processed_loc = self._location_sensitive(alignment_cum, alignment_old)
|
||||
# score = self._location_sensitive_score(
|
||||
# processed_query,
|
||||
# self.keys,
|
||||
# processed_loc)
|
||||
# alignment = self.probability_fn(score, state)
|
||||
# alignment_cum = alignment_cum + alignment
|
||||
# state[0] = alignment_cum
|
||||
# state[1] = alignment
|
||||
# return alignment, state
|
||||
|
||||
# def compute_context(self, alignments):
|
||||
# expanded_alignments = tf.expand_dims(alignments, 1)
|
||||
# context = tf.matmul(expanded_alignments, self.values)
|
||||
# context = tf.squeeze(context, [1])
|
||||
# return context
|
||||
|
||||
# # def call(self, query, state):
|
||||
# # alignment, next_state = self._calculate_attention(query, state)
|
||||
# # return alignment, next_state
|
|
@ -0,0 +1,231 @@
|
|||
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from TTS.tf.utils.tf_utils import shape_list
|
||||
from TTS.tf.layers.common_layers import Prenet, Attention
|
||||
# from tensorflow_addons.seq2seq import AttentionWrapper
|
||||
|
||||
|
||||
class ConvBNBlock(keras.layers.Layer):
|
||||
def __init__(self, filters, kernel_size, activation, **kwargs):
|
||||
super(ConvBNBlock, self).__init__(**kwargs)
|
||||
self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding='same', name='convolution1d')
|
||||
self.batch_normalization = keras.layers.BatchNormalization(axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization')
|
||||
self.dropout = keras.layers.Dropout(rate=0.5, name='dropout')
|
||||
self.activation = keras.layers.Activation(activation, name='activation')
|
||||
|
||||
def call(self, x, training=None):
|
||||
o = self.convolution1d(x)
|
||||
o = self.batch_normalization(o, training=training)
|
||||
o = self.activation(o)
|
||||
o = self.dropout(o, training=training)
|
||||
return o
|
||||
|
||||
|
||||
class Postnet(keras.layers.Layer):
|
||||
def __init__(self, output_filters, num_convs, **kwargs):
|
||||
super(Postnet, self).__init__(**kwargs)
|
||||
self.convolutions = []
|
||||
self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0'))
|
||||
for idx in range(1, num_convs - 1):
|
||||
self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}'))
|
||||
self.convolutions.append(ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{idx+1}'))
|
||||
|
||||
def call(self, x, training=None):
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
o = layer(o, training=training)
|
||||
return o
|
||||
|
||||
|
||||
class Encoder(keras.layers.Layer):
|
||||
def __init__(self, output_input_dim, **kwargs):
|
||||
super(Encoder, self).__init__(**kwargs)
|
||||
self.convolutions = []
|
||||
for idx in range(3):
|
||||
self.convolutions.append(ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}'))
|
||||
self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name='lstm')
|
||||
|
||||
def call(self, x, training=None):
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
o = layer(o, training=training)
|
||||
o = self.lstm(o)
|
||||
return o
|
||||
|
||||
|
||||
class Decoder(keras.layers.Layer):
|
||||
def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type,
|
||||
prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask,
|
||||
use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, **kwargs):
|
||||
super(Decoder, self).__init__(**kwargs)
|
||||
self.frame_dim = frame_dim
|
||||
self.r_init = tf.constant(r, dtype=tf.int32)
|
||||
self.r = tf.constant(r, dtype=tf.int32)
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
|
||||
self.stop_thresh = tf.constant(0.5, dtype=tf.float32)
|
||||
|
||||
# model dimensions
|
||||
self.query_dim = 1024
|
||||
self.decoder_rnn_dim = 1024
|
||||
self.prenet_dim = 256
|
||||
self.attn_dim = 128
|
||||
self.p_attention_dropout = 0.1
|
||||
self.p_decoder_dropout = 0.1
|
||||
|
||||
self.prenet = Prenet(prenet_type,
|
||||
prenet_dropout,
|
||||
[self.prenet_dim, self.prenet_dim],
|
||||
bias=False,
|
||||
name='prenet')
|
||||
self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'{self.name}/attention_rnn', )
|
||||
self.attention_rnn_dropout = keras.layers.Dropout(0.5)
|
||||
|
||||
# TODO: implement other attn options
|
||||
self.attention = Attention(attn_dim=self.attn_dim,
|
||||
use_loc_attn=True,
|
||||
loc_attn_n_filters=32,
|
||||
loc_attn_kernel_size=31,
|
||||
use_windowing=False,
|
||||
norm=attn_norm,
|
||||
use_forward_attn=use_forward_attn,
|
||||
use_trans_agent=use_trans_agent,
|
||||
use_forward_attn_mask=use_forward_attn_mask,
|
||||
name='attention')
|
||||
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'{self.name}/decoder_rnn')
|
||||
self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
|
||||
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'{self.name}/linear_projection/linear_layer')
|
||||
self.stopnet = keras.layers.Dense(1, name=f'{self.name}/stopnet/linear_layer')
|
||||
|
||||
|
||||
def set_max_decoder_steps(self, new_max_steps):
|
||||
self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32)
|
||||
|
||||
def set_r(self, new_r):
|
||||
self.r = tf.constant(new_r, dtype=tf.int32)
|
||||
|
||||
def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
|
||||
zero_frame = tf.zeros([batch_size, self.frame_dim])
|
||||
zero_context = tf.zeros([batch_size, memory_dim])
|
||||
attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
|
||||
decoder_rnn_state = self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
|
||||
attention_states = self.attention.init_states(batch_size, memory_length)
|
||||
return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states
|
||||
|
||||
def step(self, prenet_next, states,
|
||||
memory_seq_length=None, training=None):
|
||||
_, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states
|
||||
attention_rnn_input = tf.concat([prenet_next, context_next], -1)
|
||||
attention_rnn_output, attention_rnn_state = \
|
||||
self.attention_rnn(attention_rnn_input,
|
||||
attention_rnn_state, training=training)
|
||||
attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training)
|
||||
context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training)
|
||||
decoder_rnn_input = tf.concat([attention_rnn_output, context], -1)
|
||||
decoder_rnn_output, decoder_rnn_state = \
|
||||
self.decoder_rnn(decoder_rnn_input, decoder_rnn_state, training=training)
|
||||
decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training)
|
||||
linear_projection_input = tf.concat([decoder_rnn_output, context], -1)
|
||||
output_frame = self.linear_projection(linear_projection_input, training=training)
|
||||
stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1)
|
||||
stopnet_output = self.stopnet(stopnet_input, training=training)
|
||||
output_frame = output_frame[:, :self.r * self.frame_dim]
|
||||
states = (output_frame[:, self.frame_dim * (self.r - 1):], context, attention_rnn_state, decoder_rnn_state, attention_states)
|
||||
return output_frame, stopnet_output, states, attention
|
||||
|
||||
def decode(self, memory, states, frames, memory_seq_length=None):
|
||||
B, T, D = shape_list(memory)
|
||||
num_iter = shape_list(frames)[1] // self.r
|
||||
# init states
|
||||
frame_zero = tf.expand_dims(states[0], 1)
|
||||
frames = tf.concat([frame_zero, frames], axis=1)
|
||||
outputs = tf.TensorArray(dtype=tf.float32, size=num_iter)
|
||||
attentions = tf.TensorArray(dtype=tf.float32, size=num_iter)
|
||||
stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter)
|
||||
# pre-computes
|
||||
self.attention.process_values(memory)
|
||||
prenet_output = self.prenet(frames, training=True)
|
||||
step_count = tf.constant(0, dtype=tf.int32)
|
||||
|
||||
def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions):
|
||||
prenet_next = prenet_output[:, step]
|
||||
output, stop_token, states, attention = self.step(prenet_next,
|
||||
states,
|
||||
memory_seq_length)
|
||||
outputs = outputs.write(step, output)
|
||||
attentions = attentions.write(step, attention)
|
||||
stop_tokens = stop_tokens.write(step, stop_token)
|
||||
return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions
|
||||
_, memory, _, states, outputs, stop_tokens, attentions = \
|
||||
tf.while_loop(lambda *arg: True,
|
||||
_body,
|
||||
loop_vars=(step_count, memory, prenet_output, states, outputs,
|
||||
stop_tokens, attentions),
|
||||
parallel_iterations=32,
|
||||
swap_memory=True,
|
||||
maximum_iterations=num_iter)
|
||||
|
||||
outputs = outputs.stack()
|
||||
attentions = attentions.stack()
|
||||
stop_tokens = stop_tokens.stack()
|
||||
outputs = tf.transpose(outputs, [1, 0, 2])
|
||||
attentions = tf.transpose(attentions, [1, 0 ,2])
|
||||
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
|
||||
stop_tokens = tf.squeeze(stop_tokens, axis=2)
|
||||
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
|
||||
return outputs, stop_tokens, attentions
|
||||
|
||||
def decode_inference(self, memory, states):
|
||||
B, T, D = shape_list(memory)
|
||||
# init states
|
||||
outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
|
||||
# pre-computes
|
||||
self.attention.process_values(memory)
|
||||
|
||||
# iter vars
|
||||
stop_flag = tf.constant(False, dtype=tf.bool)
|
||||
step_count = tf.constant(0, dtype=tf.int32)
|
||||
|
||||
def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag):
|
||||
frame_next = states[0]
|
||||
prenet_next = self.prenet(frame_next, training=False)
|
||||
output, stop_token, states, attention = self.step(prenet_next,
|
||||
states,
|
||||
None,
|
||||
training=False)
|
||||
stop_token = tf.math.sigmoid(stop_token)
|
||||
outputs = outputs.write(step, output)
|
||||
attentions = attentions.write(step, attention)
|
||||
stop_tokens = stop_tokens.write(step, stop_token)
|
||||
stop_flag = tf.greater(stop_token, self.stop_thresh)
|
||||
stop_flag = tf.reduce_all(stop_flag)
|
||||
return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag
|
||||
|
||||
cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
|
||||
_, memory, states, outputs, stop_tokens, attentions, stop_flag = \
|
||||
tf.while_loop(cond,
|
||||
_body,
|
||||
loop_vars=(step_count, memory, states, outputs,
|
||||
stop_tokens, attentions, stop_flag),
|
||||
parallel_iterations=32,
|
||||
swap_memory=True,
|
||||
maximum_iterations=self.max_decoder_steps)
|
||||
|
||||
outputs = outputs.stack()
|
||||
attentions = attentions.stack()
|
||||
stop_tokens = stop_tokens.stack()
|
||||
|
||||
outputs = tf.transpose(outputs, [1, 0, 2])
|
||||
attentions = tf.transpose(attentions, [1, 0, 2])
|
||||
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
|
||||
stop_tokens = tf.squeeze(stop_tokens, axis=2)
|
||||
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
|
||||
return outputs, stop_tokens, attentions
|
||||
|
||||
def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
|
||||
if training:
|
||||
return self.decode(memory, states, frames, memory_seq_length)
|
||||
return self.decode_inference(memory, states)
|
|
@ -0,0 +1,72 @@
|
|||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
|
||||
from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
|
||||
from TTS.tf.utils.tf_utils import shape_list
|
||||
|
||||
|
||||
class Tacotron2(keras.models.Model):
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim=80,
|
||||
decoder_output_dim=80,
|
||||
attn_type='original',
|
||||
attn_win=False,
|
||||
attn_norm="softmax",
|
||||
attn_K=4,
|
||||
prenet_type="original",
|
||||
prenet_dropout=True,
|
||||
forward_attn=False,
|
||||
trans_agent=False,
|
||||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False):
|
||||
super(Tacotron2, self).__init__()
|
||||
self.r = r
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
self.num_speakers = num_speakers
|
||||
self.speaker_embed_dim = 256
|
||||
|
||||
self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
|
||||
self.encoder = Encoder(512, name='encoder')
|
||||
# TODO: most of the decoder args have no use at the momment
|
||||
self.decoder = Decoder(decoder_output_dim, r, attn_type=attn_type, use_attn_win=attn_win, attn_norm=attn_norm, prenet_type=prenet_type,
|
||||
prenet_dropout=prenet_dropout, use_forward_attn=forward_attn, use_trans_agent=trans_agent, use_forward_attn_mask=forward_attn_mask,
|
||||
use_location_attn=location_attn, attn_K=attn_K, separate_stopnet=separate_stopnet, speaker_emb_dim=self.speaker_embed_dim)
|
||||
self.postnet = Postnet(postnet_output_dim, 5, name='postnet')
|
||||
|
||||
def call(self, characters, text_lengths=None, frames=None, training=None):
|
||||
if training == True:
|
||||
return self.training(characters, text_lengths, frames)
|
||||
else:
|
||||
return self.inference(characters)
|
||||
|
||||
def training(self, characters, text_lengths, frames):
|
||||
B, T = shape_list(characters)
|
||||
embedding_vectors = self.embedding(characters, training=True)
|
||||
encoder_output = self.encoder(embedding_vectors, training=True)
|
||||
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
|
||||
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, frames, text_lengths, training=True)
|
||||
postnet_frames = self.postnet(decoder_frames, training=True)
|
||||
output_frames = decoder_frames + postnet_frames
|
||||
return decoder_frames, output_frames, attentions, stop_tokens
|
||||
|
||||
def inference(self, characters):
|
||||
B, T = shape_list(characters)
|
||||
embedding_vectors = self.embedding(characters, training=False)
|
||||
encoder_output = self.encoder(embedding_vectors, training=False)
|
||||
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
|
||||
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
|
||||
postnet_frames = self.postnet(decoder_frames, training=False)
|
||||
output_frames = decoder_frames + postnet_frames
|
||||
print(output_frames.shape)
|
||||
return decoder_frames, output_frames, attentions, stop_tokens
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,708 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"source": [
|
||||
"This is to test TTS models with benchmark sentences for speech synthesis.\n",
|
||||
"\n",
|
||||
"Before running this script please DON'T FORGET: \n",
|
||||
"- to set file paths.\n",
|
||||
"- to download related model files.\n",
|
||||
"- download or clone related repos, linked below.\n",
|
||||
"- setup the repositories. ```python setup.py install```\n",
|
||||
"- to checkout right commit versions (given next to the model in the models page).\n",
|
||||
"- to set the file paths below.\n",
|
||||
"\n",
|
||||
"Repositories:\n",
|
||||
"- TTS: https://github.com/mozilla/TTS"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false",
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# you may need to change this depending on your system\n",
|
||||
"os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
|
||||
"\n",
|
||||
"import sys\n",
|
||||
"import io\n",
|
||||
"import torch \n",
|
||||
"import tensorflow as tf\n",
|
||||
"print(tf.config.list_physical_devices('GPU'))\n",
|
||||
"\n",
|
||||
"import time\n",
|
||||
"import json\n",
|
||||
"import yaml\n",
|
||||
"import numpy as np\n",
|
||||
"from collections import OrderedDict\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"plt.rcParams[\"figure.figsize\"] = (16,5)\n",
|
||||
"\n",
|
||||
"import librosa\n",
|
||||
"import librosa.display\n",
|
||||
"\n",
|
||||
"from TTS.tf.models.tacotron2 import Tacotron2\n",
|
||||
"from TTS.tf.utils.generic_utils import setup_model, load_checkpoint\n",
|
||||
"from TTS.utils.audio import AudioProcessor\n",
|
||||
"from TTS.utils.io import load_config\n",
|
||||
"from TTS.utils.synthesis import synthesis\n",
|
||||
"from TTS.utils.visual import visualize\n",
|
||||
"\n",
|
||||
"import IPython\n",
|
||||
"from IPython.display import Audio\n",
|
||||
"\n",
|
||||
"%matplotlib agg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n",
|
||||
" t_1 = time.time()\n",
|
||||
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens, inputs = synthesis(model, text, CONFIG, use_cuda, ap, None, None, False, CONFIG.enable_eos_bos_chars, use_gl, backend=BACKEND)\n",
|
||||
" if CONFIG.model == \"Tacotron\" and not use_gl:\n",
|
||||
" # coorect the normalization differences b/w TTS and the Vocoder.\n",
|
||||
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
|
||||
" print(mel_postnet_spec.shape)\n",
|
||||
" print(\"max- \", mel_postnet_spec.max(), \" -- min- \", mel_postnet_spec.min())\n",
|
||||
" if not use_gl:\n",
|
||||
" waveform = vocoder_model.inference(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0))\n",
|
||||
" mel_postnet_spec = ap._denormalize(mel_postnet_spec.T).T\n",
|
||||
" if use_cuda and not use_gl:\n",
|
||||
" waveform = waveform.cpu()\n",
|
||||
" waveform = waveform.numpy()\n",
|
||||
" waveform = waveform.squeeze()\n",
|
||||
" rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n",
|
||||
" print(waveform.shape)\n",
|
||||
" print(\" > Run-time: {}\".format(time.time() - t_1))\n",
|
||||
" print(\" > Real-time factor: {}\".format(rtf))\n",
|
||||
" if figures: \n",
|
||||
" visualize(alignment, mel_postnet_spec, stop_tokens, text, ap.hop_length, CONFIG, ap._denormalize(mel_spec.T).T) \n",
|
||||
" IPython.display.display(Audio(waveform, rate=CONFIG.audio['sample_rate'], normalize=True)) \n",
|
||||
" os.makedirs(OUT_FOLDER, exist_ok=True)\n",
|
||||
" file_name = text.replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n",
|
||||
" out_path = os.path.join(OUT_FOLDER, file_name)\n",
|
||||
" ap.save_wav(waveform, out_path)\n",
|
||||
" return alignment, mel_postnet_spec, stop_tokens, waveform"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set constants\n",
|
||||
"ROOT_PATH = '../tf_model/'\n",
|
||||
"MODEL_PATH = ROOT_PATH + '/tts_tf_checkpoint_360000.pkl'\n",
|
||||
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
|
||||
"OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n",
|
||||
"CONFIG = load_config(CONFIG_PATH)\n",
|
||||
"# Run FLAGs\n",
|
||||
"use_cuda = True\n",
|
||||
"# Set the vocoder\n",
|
||||
"use_gl = True # use GL if True\n",
|
||||
"BACKEND = 'tf'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false",
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from TTS.utils.text.symbols import symbols, phonemes, make_symbols\n",
|
||||
"from TTS.tf.utils.convert_torch_to_tf_utils import tf_create_dummy_inputs\n",
|
||||
"c = CONFIG\n",
|
||||
"num_speakers = 0\n",
|
||||
"r = 1\n",
|
||||
"num_chars = len(phonemes) if c.use_phonemes else len(symbols)\n",
|
||||
"model = setup_model(num_chars, num_speakers, c)\n",
|
||||
"\n",
|
||||
"# before loading weights you need to run the model once to generate the variables\n",
|
||||
"input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()\n",
|
||||
"mel_pred = model(input_ids, training=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false",
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = load_checkpoint(model, MODEL_PATH)\n",
|
||||
"# model = tf.function(model, experimental_relax_shapes=True)\n",
|
||||
"ap = AudioProcessor(**CONFIG.audio) "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# wrapper class to use tf.function\n",
|
||||
"class ModelInference(tf.keras.Model):\n",
|
||||
" def __init__(self, model):\n",
|
||||
" super(ModelInference, self).__init__()\n",
|
||||
" self.model = model\n",
|
||||
" \n",
|
||||
" @tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32)])\n",
|
||||
" def call(self, characters):\n",
|
||||
" return self.model(characters, training=False)\n",
|
||||
" \n",
|
||||
"model = ModelInference(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# LOAD WAVERNN\n",
|
||||
"if use_gl == False:\n",
|
||||
" from parallel_wavegan.models import ParallelWaveGANGenerator, MelGANGenerator\n",
|
||||
" \n",
|
||||
" vocoder_model = MelGANGenerator(**VOCODER_CONFIG[\"generator_params\"])\n",
|
||||
" vocoder_model.load_state_dict(torch.load(VOCODER_MODEL_PATH, map_location=\"cpu\")[\"model\"][\"generator\"])\n",
|
||||
" vocoder_model.remove_weight_norm()\n",
|
||||
" ap_vocoder = AudioProcessor(**VOCODER_CONFIG['audio']) \n",
|
||||
" if use_cuda:\n",
|
||||
" vocoder_model.cuda()\n",
|
||||
" vocoder_model.eval();\n",
|
||||
" print(count_parameters(vocoder_model))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"source": [
|
||||
"### Comparision with https://mycroft.ai/blog/available-voices/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Bill got in the habit of asking himself “Is that thought true?” and if he wasn’t absolutely certain it was, he just let it go.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"source": [
|
||||
"### https://espnet.github.io/icassp2020-tts/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"The Commission also recommends\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"As a result of these studies, the planning document submitted by the Secretary of the Treasury to the Bureau of the Budget on August thirty-one.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"The FBI now transmits information on all defectors, a category which would, of course, have included Oswald.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"they seem unduly restrictive in continuing to require some manifestation of animus against a Government official.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"and each agency given clear understanding of the assistance which the Secret Service expects.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"source": [
|
||||
"### Other examples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Be a voice, not an echo.\" # 'echo' is not in training set. \n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"The human voice is the most perfect instrument of all.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"This cake is great. It's so delicious and moist.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"source": [
|
||||
"### Comparison with https://keithito.github.io/audio-samples/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Here’s a way to measure the acute emotional intelligence that has never gone out of style.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"The buses aren't the problem, they actually provide a solution.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"source": [
|
||||
"### Comparison with https://google.github.io/tacotron/publications/tacotron/index.html"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \" He has read the whole thing.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"He reads books.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Thisss isrealy awhsome.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"This is your internet browser, Firefox.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"This is your internet browser Firefox.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"The quick brown fox jumps over the lazy dog.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Does the quick brown fox jump over the lazy dog?\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Eren, how are you?\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"source": [
|
||||
"### Hard Sentences"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Encouraged, he started with a minute a day.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . \"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"If he decided to watch TV he really watched it.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Often we try to bring about change through sheer effort and we put all of our energy into a new initiative .\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# for twb dataset\n",
|
||||
"sentence = \"In our preparation for Easter, God in his providence offers us each year the season of Lent as a sacramental sign of our conversion.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"wavs = []\n",
|
||||
"model.eval()\n",
|
||||
"model.decoder.prenet.eval()\n",
|
||||
"model.decoder.max_decoder_steps = 2000\n",
|
||||
"# model.decoder.prenet.train()\n",
|
||||
"speaker_id = None\n",
|
||||
"sentence = '''This is App Store Optimization report.\n",
|
||||
"The first tab on the report is App Details. App details report is updated weekly and Datetime column shows the latest report update date. The widget displays the app icon, respective app version, visual assets on the store, app description, latest app update date on the Appstore/Google PlayStore and what’s new section.\n",
|
||||
"In App Details tab, you can see not only your app but all Delivery Hero apps since we think it can be inspiring to see the other apps, their description and screenshots. \n",
|
||||
"Product name is the actual app name on the AppStore or Google Play Store.\n",
|
||||
"Screenshot URLs column display the actual screenshots on the store for the current version. No resizing is done. If you click on the screenshot, you can see it in full-size.\n",
|
||||
"Current release date show the latest app update date when the query is run. Here we see that Appetito24 Android is updated to app version 4.6.3.2 on 28th of March.\n",
|
||||
"If the description is too long, clarisights is not able to display the full description; however, if you select description and current_release_date cells to copy and paste it to a text editor, you'll see the full description.\n",
|
||||
"If you scroll down in the widget, you can see the older app versions for the same apps. Or you can filter Datetime to see a specific timeframe and the apps’ Store presence back then.\n",
|
||||
"You can also filter for a specific app using Product Name.\n",
|
||||
"If the description is too long, clarisights is not able to display the full description; however, if you select description and current_release_date cells to copy and paste it to a text editor, you'll see the full description.\n",
|
||||
"'''\n",
|
||||
"\n",
|
||||
"for s in sentence.split('\\n'):\n",
|
||||
" print(s)\n",
|
||||
" align, spec, stop_tokens, wav = tts(model, s, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)\n",
|
||||
" wavs = np.concatenate([wavs, np.zeros(int(ap.sample_rate * 0.5)), wav])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"Collapsed": "false"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
|
@ -0,0 +1,2 @@
|
|||
fuzzywuzzy
|
||||
tensorflow>=2.2.0
|
|
@ -0,0 +1,83 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import re
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras.backend as K
|
||||
|
||||
|
||||
def tf_create_dummy_inputs():
|
||||
""" Create dummy inputs for TF Tacotron2 model """
|
||||
batch_size = 4
|
||||
max_input_length = 32
|
||||
max_mel_length = 128
|
||||
pad = 1
|
||||
n_chars = 24
|
||||
input_ids = tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32)
|
||||
input_lengths = np.random.randint(0, high=max_input_length+1 + pad, size=[batch_size])
|
||||
input_lengths[-1] = max_input_length
|
||||
input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32)
|
||||
mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80])
|
||||
mel_lengths = np.random.randint(0, high=max_mel_length+1 + pad, size=[batch_size])
|
||||
mel_lengths[-1] = max_mel_length
|
||||
mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32)
|
||||
return input_ids, input_lengths, mel_outputs, mel_lengths
|
||||
|
||||
|
||||
def compare_torch_tf(torch_tensor, tf_tensor):
|
||||
""" Compute the average absolute difference b/w torch and tf tensors """
|
||||
return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean()
|
||||
|
||||
|
||||
def convert_tf_name(tf_name):
|
||||
""" Convert certain patterns in TF layer names to Torch patterns """
|
||||
tf_name_tmp = tf_name
|
||||
tf_name_tmp = tf_name_tmp.replace(':0', '')
|
||||
tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0')
|
||||
tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1')
|
||||
tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh')
|
||||
tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight')
|
||||
tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight')
|
||||
tf_name_tmp = tf_name_tmp.replace('/beta', '/bias')
|
||||
tf_name_tmp = tf_name_tmp.replace('/', '.')
|
||||
return tf_name_tmp
|
||||
|
||||
|
||||
def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
|
||||
""" Transfer weigths from torch state_dict to TF variables """
|
||||
print(" > Passing weights from Torch to TF ...")
|
||||
for tf_var in tf_vars:
|
||||
torch_var_name = var_map_dict[tf_var.name]
|
||||
print(f' | > {tf_var.name} <-- {torch_var_name}')
|
||||
# if tuple, it is a bias variable
|
||||
if type(torch_var_name) is not tuple:
|
||||
torch_layer_name = '.'.join(torch_var_name.split('.')[-2:])
|
||||
torch_weight = state_dict[torch_var_name]
|
||||
if 'convolution1d/kernel' in tf_var.name or 'conv1d/kernel' in tf_var.name:
|
||||
# out_dim, in_dim, filter -> filter, in_dim, out_dim
|
||||
numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy()
|
||||
elif 'lstm_cell' in tf_var.name and 'kernel' in tf_var.name:
|
||||
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
|
||||
# if variable is for bidirectional lstm and it is a bias vector there
|
||||
# needs to be pre-defined two matching torch bias vectors
|
||||
elif '_lstm/lstm_cell_' in tf_var.name and 'bias' in tf_var.name:
|
||||
bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name]
|
||||
assert len(bias_vectors) == 2
|
||||
numpy_weight = bias_vectors[0] + bias_vectors[1]
|
||||
elif 'rnn' in tf_var.name and 'kernel' in tf_var.name:
|
||||
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
|
||||
elif 'rnn' in tf_var.name and 'bias' in tf_var.name:
|
||||
bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key]
|
||||
assert len(bias_vectors) == 2
|
||||
numpy_weight = bias_vectors[0] + bias_vectors[1]
|
||||
elif 'linear_layer' in torch_layer_name and 'weight' in torch_var_name:
|
||||
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
|
||||
else:
|
||||
numpy_weight = torch_weight.detach().cpu().numpy()
|
||||
assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
|
||||
tf.keras.backend.set_value(tf_var, numpy_weight)
|
||||
|
||||
|
||||
def load_tf_vars(model_tf, tf_vars):
|
||||
for tf_var in tf_vars:
|
||||
model_tf.get_layer(tf_var.name).set_weights(tf_var)
|
||||
return model_tf
|
|
@ -0,0 +1,105 @@
|
|||
import os
|
||||
import re
|
||||
import glob
|
||||
import shutil
|
||||
import datetime
|
||||
import json
|
||||
import subprocess
|
||||
import importlib
|
||||
import pickle
|
||||
import numpy as np
|
||||
from collections import OrderedDict, Counter
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
|
||||
checkpoint_path = 'tts_tf_checkpoint_{}.pkl'.format(current_step)
|
||||
checkpoint_path = os.path.join(output_folder, checkpoint_path)
|
||||
state = {
|
||||
'model': model.weights,
|
||||
'optimizer': optimizer,
|
||||
'step': current_step,
|
||||
'epoch': epoch,
|
||||
'date': datetime.date.today().strftime("%B %d, %Y"),
|
||||
'r': r
|
||||
}
|
||||
state.update(kwargs)
|
||||
pickle.dump(state, open(checkpoint_path, 'wb'))
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path):
|
||||
checkpoint = pickle.load(open(checkpoint_path, 'rb'))
|
||||
chkp_var_dict = dict([(var.name, var.numpy()) for var in checkpoint['model']])
|
||||
tf_vars = model.weights
|
||||
for tf_var in tf_vars:
|
||||
layer_name = tf_var.name
|
||||
chkp_var_value = chkp_var_dict[layer_name]
|
||||
tf.keras.backend.set_value(tf_var, chkp_var_value)
|
||||
if 'r' in checkpoint.keys():
|
||||
model.decoder.set_r(checkpoint['r'])
|
||||
return model
|
||||
|
||||
|
||||
def sequence_mask(sequence_length, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = sequence_length.max()
|
||||
batch_size = sequence_length.size(0)
|
||||
seq_range = np.empty([0, max_len], dtype=np.int8)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
if sequence_length.is_cuda:
|
||||
seq_range_expand = seq_range_expand.cuda()
|
||||
seq_length_expand = (
|
||||
sequence_length.unsqueeze(1).expand_as(seq_range_expand))
|
||||
# B x T_max
|
||||
return seq_range_expand < seq_length_expand
|
||||
|
||||
|
||||
# @tf.custom_gradient
|
||||
def check_gradient(x, grad_clip):
|
||||
x_normed = tf.clip_by_norm(x, grad_clip)
|
||||
grad_norm = tf.norm(grad_clip)
|
||||
return x_normed, grad_norm
|
||||
|
||||
|
||||
def count_parameters(model, c):
|
||||
try:
|
||||
return model.count_params()
|
||||
except:
|
||||
input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype('int32'))
|
||||
input_lengths = np.random.randint(100, 129, (8, ))
|
||||
input_lengths[-1] = 128
|
||||
input_lengths = tf.convert_to_tensor(input_lengths.astype('int32'))
|
||||
mel_spec = np.random.rand(8, 2 * c.r,
|
||||
c.audio['num_mels']).astype('float32')
|
||||
mel_spec = tf.convert_to_tensor(mel_spec)
|
||||
speaker_ids = np.random.randint(
|
||||
0, 5, (8, )) if c.use_speaker_embedding else None
|
||||
_ = model(input_dummy, input_lengths, mel_spec)
|
||||
return model.count_params()
|
||||
|
||||
|
||||
def setup_model(num_chars, num_speakers, c):
|
||||
print(" > Using model: {}".format(c.model))
|
||||
MyModel = importlib.import_module('TTS.tf.models.' + c.model.lower())
|
||||
MyModel = getattr(MyModel, c.model)
|
||||
if c.model.lower() in "tacotron":
|
||||
raise NotImplemented(' [!] Tacotron model is not ready.')
|
||||
elif c.model.lower() == "tacotron2":
|
||||
model = MyModel(num_chars=num_chars,
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
postnet_output_dim=c.audio['num_mels'],
|
||||
decoder_output_dim=c.audio['num_mels'],
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
prenet_type=c.prenet_type,
|
||||
prenet_dropout=c.prenet_dropout,
|
||||
forward_attn=c.use_forward_attn,
|
||||
trans_agent=c.transition_agent,
|
||||
forward_attn_mask=c.forward_attn_mask,
|
||||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder)
|
||||
return model
|
|
@ -0,0 +1,8 @@
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
def shape_list(x):
|
||||
"""Deal with dynamic shape in tensorflow cleanly."""
|
||||
static = x.shape.as_list()
|
||||
dynamic = tf.shape(x)
|
||||
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
|
|
@ -99,7 +99,7 @@ def sequence_mask(sequence_length, max_len=None):
|
|||
seq_range = torch.arange(0, max_len).long()
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
if sequence_length.is_cuda:
|
||||
seq_range_expand = seq_range_expand.cuda()
|
||||
seq_range_expand = seq_range_expand.to(sequence_length.device)
|
||||
seq_length_expand = (
|
||||
sequence_length.unsqueeze(1).expand_as(seq_range_expand))
|
||||
# B x T_max
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
import pkg_resources
|
||||
installed = {pkg.key for pkg in pkg_resources.working_set}
|
||||
if 'tensorflow' in installed or 'tensorflow-gpu' in installed:
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import numpy as np
|
||||
from .text import text_to_sequence, phoneme_to_sequence
|
||||
|
@ -14,23 +18,32 @@ def text_to_seqvec(text, CONFIG, use_cuda):
|
|||
dtype=np.int32)
|
||||
else:
|
||||
seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32)
|
||||
# torch tensor
|
||||
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
||||
if use_cuda:
|
||||
chars_var = chars_var.cuda()
|
||||
return chars_var.long()
|
||||
return seq
|
||||
|
||||
|
||||
def numpy_to_torch(np_array, dtype, cuda=False):
|
||||
if np_array is None:
|
||||
return None
|
||||
tensor = torch.as_tensor(np_array, dtype=dtype)
|
||||
if cuda:
|
||||
return tensor.cuda()
|
||||
return tensor
|
||||
|
||||
|
||||
def numpy_to_tf(np_array, dtype):
|
||||
if np_array is None:
|
||||
return None
|
||||
tensor = tf.convert_to_tensor(np_array, dtype=dtype)
|
||||
return tensor
|
||||
|
||||
|
||||
def compute_style_mel(style_wav, ap, use_cuda):
|
||||
print(style_wav)
|
||||
style_mel = torch.FloatTensor(ap.melspectrogram(
|
||||
ap.load_wav(style_wav))).unsqueeze(0)
|
||||
if use_cuda:
|
||||
return style_mel.cuda()
|
||||
style_mel = ap.melspectrogram(
|
||||
ap.load_wav(style_wav)).expand_dims(0)
|
||||
return style_mel
|
||||
|
||||
|
||||
def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
|
||||
def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
|
||||
if CONFIG.use_gst:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
inputs, style_mel=style_mel, speaker_ids=speaker_id)
|
||||
|
@ -44,11 +57,31 @@ def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None)
|
|||
return decoder_output, postnet_output, alignments, stop_tokens
|
||||
|
||||
|
||||
def parse_outputs(postnet_output, decoder_output, alignments):
|
||||
def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
|
||||
if CONFIG.use_gst:
|
||||
raise NotImplemented(' [!] GST inference not implemented for TF')
|
||||
if truncated:
|
||||
raise NotImplemented(' [!] Truncated inference not implemented for TF')
|
||||
# TODO: handle multispeaker case
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model(
|
||||
inputs)
|
||||
return decoder_output, postnet_output, alignments, stop_tokens
|
||||
|
||||
|
||||
def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
|
||||
postnet_output = postnet_output[0].data.cpu().numpy()
|
||||
decoder_output = decoder_output[0].data.cpu().numpy()
|
||||
alignment = alignments[0].cpu().data.numpy()
|
||||
return postnet_output, decoder_output, alignment
|
||||
stop_tokens = stop_tokens[0].cpu().numpy()
|
||||
return postnet_output, decoder_output, alignment, stop_tokens
|
||||
|
||||
|
||||
def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
|
||||
postnet_output = postnet_output[0].numpy()
|
||||
decoder_output = decoder_output[0].numpy()
|
||||
alignment = alignments[0].numpy()
|
||||
stop_tokens = stop_tokens[0].numpy()
|
||||
return postnet_output, decoder_output, alignment, stop_tokens
|
||||
|
||||
|
||||
def trim_silence(wav, ap):
|
||||
|
@ -98,7 +131,8 @@ def synthesis(model,
|
|||
truncated=False,
|
||||
enable_eos_bos_chars=False, #pylint: disable=unused-argument
|
||||
use_griffin_lim=False,
|
||||
do_trim_silence=False):
|
||||
do_trim_silence=False,
|
||||
backend='torch'):
|
||||
"""Synthesize voice for the given text.
|
||||
|
||||
Args:
|
||||
|
@ -114,6 +148,7 @@ def synthesis(model,
|
|||
for continuous inference at long texts.
|
||||
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
|
||||
do_trim_silence (bool): trim silence after synthesis.
|
||||
backend (str): tf or torch
|
||||
"""
|
||||
# GST processing
|
||||
style_mel = None
|
||||
|
@ -121,15 +156,29 @@ def synthesis(model,
|
|||
style_mel = compute_style_mel(style_wav, ap, use_cuda)
|
||||
# preprocess the given text
|
||||
inputs = text_to_seqvec(text, CONFIG, use_cuda)
|
||||
speaker_id = id_to_torch(speaker_id)
|
||||
if speaker_id is not None and use_cuda:
|
||||
speaker_id = speaker_id.cuda()
|
||||
# pass tensors to backend
|
||||
if backend == 'torch':
|
||||
speaker_id = id_to_torch(speaker_id)
|
||||
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
|
||||
inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
|
||||
inputs = inputs.unsqueeze(0)
|
||||
else:
|
||||
# TODO: handle speaker id for tf model
|
||||
style_mel = numpy_to_tf(style_mel, tf.float32)
|
||||
inputs = numpy_to_tf(inputs, tf.int32)
|
||||
inputs = tf.expand_dims(inputs, 0)
|
||||
# synthesize voice
|
||||
decoder_output, postnet_output, alignments, stop_tokens = run_model(
|
||||
model, inputs, CONFIG, truncated, speaker_id, style_mel)
|
||||
if backend == 'torch':
|
||||
decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
|
||||
model, inputs, CONFIG, truncated, speaker_id, style_mel)
|
||||
postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
|
||||
postnet_output, decoder_output, alignments, stop_tokens)
|
||||
else:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
|
||||
model, inputs, CONFIG, truncated, speaker_id, style_mel)
|
||||
postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
|
||||
postnet_output, decoder_output, alignments, stop_tokens)
|
||||
# convert outputs to numpy
|
||||
postnet_output, decoder_output, alignment = parse_outputs(
|
||||
postnet_output, decoder_output, alignments)
|
||||
# plot results
|
||||
wav = None
|
||||
if use_griffin_lim:
|
||||
|
|
|
@ -61,7 +61,6 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
|
|||
plt.yticks(range(len(text)), list(text))
|
||||
plt.colorbar()
|
||||
# plot stopnet predictions
|
||||
stop_tokens = stop_tokens.squeeze().detach().to('cpu').numpy()
|
||||
plt.subplot(num_plot, 1, 2)
|
||||
plt.plot(range(len(stop_tokens)), list(stop_tokens))
|
||||
# plot postnet spectrogram
|
||||
|
|
Loading…
Reference in New Issue