Merge branch 'tf-convert2' into dev

pull/10/head
erogol 2020-05-18 13:13:21 +02:00
commit df8fd3823d
23 changed files with 1913 additions and 131 deletions

View File

@ -1,5 +1,5 @@
{
"model": "Tacotron2",
"model": "Tacotron2",
"run_name": "ljspeech",
"run_description": "tacotron2",
@ -11,12 +11,12 @@
"hop_length": 256, // stft window hop-lengh in ms.
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
// Audio processing parameters
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
// Silence trimming
"do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
@ -26,7 +26,7 @@
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
// MelSpectrogram parameters
"num_mels": 80, // size of the mel spec frame.
"num_mels": 80, // size of the mel spec frame.
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
@ -50,7 +50,7 @@
// "punctuations":"!'(),-.:;? ",
// "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
// },
// DISTRIBUTED TRAINING
"distributed":{
"backend": "nccl",
@ -61,8 +61,8 @@
// TRAINING
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size":16,
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
"eval_batch_size":16,
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
"gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"ga_alpha": 10.0, // weight for guided attention loss. If > 0, guided attention is enabled.
@ -80,11 +80,11 @@
"wd": 0.000001, // Weight decay weight.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths.
// TACOTRON PRENET
"memory_size": -1, // ONLY TACOTRON - size of the memory queue used fro storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame.
"memory_size": -1, // ONLY TACOTRON - size of the memory queue used fro storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame.
"prenet_type": "original", // "original" or "bn".
"prenet_dropout": true, // enable/disable dropout at prenet.
"prenet_dropout": true, // enable/disable dropout at prenet.
// ATTENTION
"attention_type": "original", // 'original' or 'graves'
@ -98,16 +98,16 @@
"bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
// STOPNET
"stopnet": true, // Train stopnet predicting the end of synthesis.
"stopnet": true, // Train stopnet predicting the end of synthesis.
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
// TENSORBOARD and LOGGING
"print_step": 25, // Number of steps to log traning on console.
"print_eval": false, // If True, it prints loss values in evalulation.
"print_eval": false, // If True, it prints loss values in evalulation.
"save_step": 10000, // Number of training steps expected to save traninpg stats and checkpoints.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
// DATA LOADING
"text_cleaner": "phoneme_cleaners",
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
@ -119,7 +119,7 @@
// PATHS
"output_path": "/home/erogol/Models/LJSpeech/",
// PHONEMES
"phoneme_cache_path": "mozilla_us_phonemes_3", // phoneme computation is slow, therefore, it caches results in the given folder.
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.

View File

@ -33,7 +33,7 @@ class LinearBN(nn.Module):
super(LinearBN, self).__init__()
self.linear_layer = torch.nn.Linear(
in_features, out_features, bias=bias)
self.bn = nn.BatchNorm1d(out_features)
self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5)
self._init_w(init_gain)
def _init_w(self, init_gain):
@ -45,7 +45,7 @@ class LinearBN(nn.Module):
out = self.linear_layer(x)
if len(out.shape) == 3:
out = out.permute(1, 2, 0)
out = self.bn(out)
out = self.batch_normalization(out)
if len(out.shape) == 3:
out = out.permute(2, 0, 1)
return out
@ -63,18 +63,18 @@ class Prenet(nn.Module):
self.prenet_dropout = prenet_dropout
in_features = [in_features] + out_features[:-1]
if prenet_type == "bn":
self.layers = nn.ModuleList([
self.linear_layers = nn.ModuleList([
LinearBN(in_size, out_size, bias=bias)
for (in_size, out_size) in zip(in_features, out_features)
])
elif prenet_type == "original":
self.layers = nn.ModuleList([
self.linear_layers = nn.ModuleList([
Linear(in_size, out_size, bias=bias)
for (in_size, out_size) in zip(in_features, out_features)
])
def forward(self, x):
for linear in self.layers:
for linear in self.linear_layers:
if self.prenet_dropout:
x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
else:
@ -93,7 +93,7 @@ class LocationLayer(nn.Module):
attention_n_filters=32,
attention_kernel_size=31):
super(LocationLayer, self).__init__()
self.location_conv = nn.Conv1d(
self.location_conv1d = nn.Conv1d(
in_channels=2,
out_channels=attention_n_filters,
kernel_size=attention_kernel_size,
@ -104,7 +104,7 @@ class LocationLayer(nn.Module):
attention_n_filters, attention_dim, bias=False, init_gain='tanh')
def forward(self, attention_cat):
processed_attention = self.location_conv(attention_cat)
processed_attention = self.location_conv1d(attention_cat)
processed_attention = self.location_dense(
processed_attention.transpose(1, 2))
return processed_attention
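The renames above (`bn` → `batch_normalization`, `layers` → `linear_layers`, `location_conv` → `location_conv1d`) align the Torch attribute names with the TF port added later in this PR. A rough sketch, with assumed import path and shapes, of the renamed `Prenet` in use:

```python
# Sketch only: exercise the renamed Prenet attributes (import path and shapes assumed).
import torch
from TTS.layers.common_layers import Prenet

prenet = Prenet(80,                      # in_features, e.g. the mel frame size
                prenet_type='original',  # or 'bn' to use the LinearBN layers
                prenet_dropout=True,
                out_features=[256, 256],
                bias=False)
frames = torch.rand(8, 30, 80)           # B x T x in_features
out = prenet(frames)                     # B x T x 256
assert len(prenet.linear_layers) == 2    # attribute renamed from `self.layers`
```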

View File

@ -6,130 +6,128 @@ from .common_layers import init_attn, Prenet, Linear
class ConvBNBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, nonlinear=None):
def __init__(self, in_channels, out_channels, kernel_size, activation=None):
super(ConvBNBlock, self).__init__()
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
conv1d = nn.Conv1d(in_channels,
self.convolution1d = nn.Conv1d(in_channels,
out_channels,
kernel_size,
padding=padding)
norm = nn.BatchNorm1d(out_channels)
dropout = nn.Dropout(p=0.5)
if nonlinear == 'relu':
self.net = nn.Sequential(conv1d, norm, nn.ReLU(), dropout)
elif nonlinear == 'tanh':
self.net = nn.Sequential(conv1d, norm, nn.Tanh(), dropout)
self.batch_normalization = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5)
self.dropout = nn.Dropout(p=0.5)
if activation == 'relu':
self.activation = nn.ReLU()
elif activation == 'tanh':
self.activation = nn.Tanh()
else:
self.net = nn.Sequential(conv1d, norm, dropout)
self.activation = nn.Identity()
def forward(self, x):
output = self.net(x)
return output
o = self.convolution1d(x)
o = self.batch_normalization(o)
o = self.activation(o)
o = self.dropout(o)
return o
class Postnet(nn.Module):
def __init__(self, mel_dim, num_convs=5):
def __init__(self, output_dim, num_convs=5):
super(Postnet, self).__init__()
self.convolutions = nn.ModuleList()
self.convolutions.append(
ConvBNBlock(mel_dim, 512, kernel_size=5, nonlinear='tanh'))
ConvBNBlock(output_dim, 512, kernel_size=5, activation='tanh'))
for _ in range(1, num_convs - 1):
self.convolutions.append(
ConvBNBlock(512, 512, kernel_size=5, nonlinear='tanh'))
ConvBNBlock(512, 512, kernel_size=5, activation='tanh'))
self.convolutions.append(
ConvBNBlock(512, mel_dim, kernel_size=5, nonlinear=None))
ConvBNBlock(512, output_dim, kernel_size=5, activation=None))
def forward(self, x):
o = x
for layer in self.convolutions:
x = layer(x)
return x
o = layer(o)
return o
class Encoder(nn.Module):
def __init__(self, in_features=512):
def __init__(self, output_input_dim=512):
super(Encoder, self).__init__()
convolutions = []
self.convolutions = nn.ModuleList()
for _ in range(3):
convolutions.append(
ConvBNBlock(in_features, in_features, 5, 'relu'))
self.convolutions = nn.Sequential(*convolutions)
self.lstm = nn.LSTM(in_features,
int(in_features / 2),
self.convolutions.append(
ConvBNBlock(output_input_dim, output_input_dim, 5, 'relu'))
self.lstm = nn.LSTM(output_input_dim,
int(output_input_dim / 2),
num_layers=1,
batch_first=True,
bias=True,
bidirectional=True)
self.rnn_state = None
def forward(self, x, input_lengths):
x = self.convolutions(x)
x = x.transpose(1, 2)
x = nn.utils.rnn.pack_padded_sequence(x,
o = x
for layer in self.convolutions:
o = layer(o)
o = o.transpose(1, 2)
o = nn.utils.rnn.pack_padded_sequence(o,
input_lengths,
batch_first=True)
self.lstm.flatten_parameters()
outputs, _ = self.lstm(x)
outputs, _ = nn.utils.rnn.pad_packed_sequence(
outputs,
batch_first=True,
)
return outputs
o, _ = self.lstm(o)
o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True)
return o
def inference(self, x):
x = self.convolutions(x)
x = x.transpose(1, 2)
self.lstm.flatten_parameters()
outputs, _ = self.lstm(x)
return outputs
def inference_truncated(self, x):
"""
Preserve encoder state for continuous inference
"""
x = self.convolutions(x)
x = x.transpose(1, 2)
self.lstm.flatten_parameters()
outputs, self.rnn_state = self.lstm(x, self.rnn_state)
return outputs
o = x
for layer in self.convolutions:
o = layer(o)
o = o.transpose(1, 2)
# self.lstm.flatten_parameters()
o, _ = self.lstm(o)
return o
# adapted from https://github.com/NVIDIA/tacotron2/
class Decoder(nn.Module):
# Pylint gets confused by PyTorch conventions here
#pylint: disable=attribute-defined-outside-init
def __init__(self, in_features, memory_dim, r, attn_type, attn_win, attn_norm,
def __init__(self, input_dim, frame_dim, r, attn_type, attn_win, attn_norm,
prenet_type, prenet_dropout, forward_attn, trans_agent,
forward_attn_mask, location_attn, attn_K, separate_stopnet,
speaker_embedding_dim):
super(Decoder, self).__init__()
self.memory_dim = memory_dim
self.frame_dim = frame_dim
self.r_init = r
self.r = r
self.encoder_embedding_dim = in_features
self.encoder_embedding_dim = input_dim
self.separate_stopnet = separate_stopnet
self.max_decoder_steps = 1000
self.gate_threshold = 0.5
# model dimensions
self.query_dim = 1024
self.decoder_rnn_dim = 1024
self.prenet_dim = 256
self.max_decoder_steps = 1000
self.gate_threshold = 0.5
self.attn_dim = 128
self.p_attention_dropout = 0.1
self.p_decoder_dropout = 0.1
# memory -> |Prenet| -> processed_memory
prenet_dim = self.memory_dim
self.prenet = Prenet(
prenet_dim,
prenet_type,
prenet_dropout,
out_features=[self.prenet_dim, self.prenet_dim],
bias=False)
prenet_dim = self.frame_dim
self.prenet = Prenet(prenet_dim,
prenet_type,
prenet_dropout,
out_features=[self.prenet_dim, self.prenet_dim],
bias=False)
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
self.query_dim)
self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim,
self.query_dim,
bias=True)
self.attention = init_attn(attn_type=attn_type,
query_dim=self.query_dim,
embedding_dim=in_features,
embedding_dim=input_dim,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
@ -141,15 +139,16 @@ class Decoder(nn.Module):
forward_attn_mask=forward_attn_mask,
attn_K=attn_K)
self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
self.decoder_rnn_dim, 1)
self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim,
self.decoder_rnn_dim,
bias=True)
self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
self.memory_dim * self.r_init)
self.linear_projection = Linear(self.decoder_rnn_dim + input_dim,
self.frame_dim * self.r_init)
self.stopnet = nn.Sequential(
nn.Dropout(0.1),
Linear(self.decoder_rnn_dim + self.memory_dim * self.r_init,
Linear(self.decoder_rnn_dim + self.frame_dim * self.r_init,
1,
bias=True,
init_gain='sigmoid'))
@ -161,7 +160,7 @@ class Decoder(nn.Module):
def get_go_frame(self, inputs):
B = inputs.size(0)
memory = torch.zeros(1, device=inputs.device).repeat(B,
self.memory_dim * self.r)
self.frame_dim * self.r)
return memory
def _init_states(self, inputs, mask, keep_states=False):
@ -187,9 +186,9 @@ class Decoder(nn.Module):
Reshape the spectrograms for given 'r'
"""
# Grouping multiple frames if necessary
if memory.size(-1) == self.memory_dim:
if memory.size(-1) == self.frame_dim:
memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
# Time first (T_decoder, B, memory_dim)
# Time first (T_decoder, B, frame_dim)
memory = memory.transpose(0, 1)
return memory
@ -197,22 +196,22 @@ class Decoder(nn.Module):
alignments = torch.stack(alignments).transpose(0, 1)
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
outputs = outputs.view(outputs.size(0), -1, self.memory_dim)
outputs = outputs.view(outputs.size(0), -1, self.frame_dim)
outputs = outputs.transpose(1, 2)
return outputs, stop_tokens, alignments
def _update_memory(self, memory):
if len(memory.shape) == 2:
return memory[:, self.memory_dim * (self.r - 1):]
return memory[:, :, self.memory_dim * (self.r - 1):]
return memory[:, self.frame_dim * (self.r - 1):]
return memory[:, :, self.frame_dim * (self.r - 1):]
def decode(self, memory):
'''
shapes:
- memory: B x r * self.memory_dim
- memory: B x r * self.frame_dim
'''
# self.context: B x D_en
# query_input: B x D_en + (r * self.memory_dim)
# query_input: B x D_en + (r * self.frame_dim)
query_input = torch.cat((memory, self.context), -1)
# self.query and self.attention_rnn_cell_state : B x D_attn_rnn
self.query, self.attention_rnn_cell_state = self.attention_rnn(
@ -235,16 +234,16 @@ class Decoder(nn.Module):
# B x (D_decoder_rnn + D_en)
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
dim=1)
# B x (self.r * self.memory_dim)
# B x (self.r * self.frame_dim)
decoder_output = self.linear_projection(decoder_hidden_context)
# B x (D_decoder_rnn + (self.r * self.memory_dim))
# B x (D_decoder_rnn + (self.r * self.frame_dim))
stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
if self.separate_stopnet:
stop_token = self.stopnet(stopnet_input.detach())
else:
stop_token = self.stopnet(stopnet_input)
# select outputs for the reduction rate self.r
decoder_output = decoder_output[:, :self.r * self.memory_dim]
decoder_output = decoder_output[:, :self.r * self.frame_dim]
return decoder_output, self.attention.attention_weights, stop_token
def forward(self, inputs, memories, mask, speaker_embeddings=None):
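As a side note on the reduction-factor handling above: `_reshape_memory` groups `r` consecutive target frames into one decoder step, and `_update_memory` keeps only the last frame of each group for the next autoregressive step. A tiny illustration with assumed sizes:

```python
# Illustration only (sizes assumed): frame grouping for reduction factor r.
import torch

B, T, frame_dim, r = 8, 28, 80, 7
memory = torch.rand(B, T, frame_dim)         # target mel frames, T divisible by r

# _reshape_memory: r consecutive frames per decoder step
grouped = memory.view(B, T // r, -1)         # B x (T // r) x (r * frame_dim)

# _update_memory: keep only the last frame of each group
last = grouped[:, :, frame_dim * (r - 1):]   # B x (T // r) x frame_dim
```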

View File

@ -29,7 +29,7 @@ class Tacotron2(nn.Module):
super(Tacotron2, self).__init__()
self.postnet_output_dim = postnet_output_dim
self.decoder_output_dim = decoder_output_dim
self.n_frames_per_step = r
self.r = r
self.bidirectional_decoder = bidirectional_decoder
decoder_dim = 512 if num_speakers > 1 else 512
encoder_dim = 512 if num_speakers > 1 else 512

View File

@ -6,7 +6,8 @@ import torch as T
from TTS.server.synthesizer import Synthesizer
from TTS.tests import get_tests_input_path, get_tests_output_path
from TTS.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model
from TTS.utils.generic_utils import setup_model
from TTS.utils.io import load_config, save_checkpoint
class DemoServerTest(unittest.TestCase):

View File

@ -5,7 +5,7 @@ import torch
import numpy as np
from torch.utils.data import DataLoader
from TTS.utils.generic_utils import load_config
from TTS.utils.io import load_config
from TTS.utils.audio import AudioProcessor
from TTS.datasets import TTSDataset
from TTS.datasets.preprocess import ljspeech

View File

@ -6,7 +6,7 @@ import numpy as np
from torch import optim
from torch import nn
from TTS.utils.generic_utils import load_config
from TTS.utils.io import load_config
from TTS.layers.losses import MSELossMasked
from TTS.models.tacotron2 import Tacotron2

View File

@ -0,0 +1,59 @@
import os
import copy
import torch
import unittest
import numpy as np
import tensorflow as tf
from torch import optim
from torch import nn
from TTS.utils.io import load_config
from TTS.layers.losses import MSELossMasked
from TTS.tf.models.tacotron2 import Tacotron2
#pylint: disable=unused-variable
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
file_path = os.path.dirname(os.path.realpath(__file__))
c = load_config(os.path.join(file_path, 'test_config.json'))
class TacotronTFTrainTest(unittest.TestCase):
def test_train_step(self):
''' test forward pass '''
input = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8, )).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]
mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
input = tf.convert_to_tensor(input.cpu().numpy())
input_lengths = tf.convert_to_tensor(input_lengths.cpu().numpy())
mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy())
for idx in mel_lengths:
stop_targets[:, int(idx.item()):, 0] = 1.0
stop_targets = stop_targets.view(input.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5)
# training pass
output = model(input, input_lengths, mel_spec, training=True)
# check model output shapes
assert np.all(output[0].shape == mel_spec.shape)
assert np.all(output[1].shape == mel_spec.shape)
assert output[2].shape[2] == input.shape[1]
assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r)
assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r)
# inference pass
output = model(input, training=False)

View File

@ -5,7 +5,7 @@ import unittest
from torch import optim
from torch import nn
from TTS.utils.generic_utils import load_config
from TTS.utils.io import load_config
from TTS.layers.losses import L1LossMasked
from TTS.models.tacotron import Tacotron

View File

@ -5,7 +5,7 @@ import os
import unittest
from TTS.utils.text import *
from TTS.tests import get_tests_path
from TTS.utils.generic_utils import load_config
from TTS.utils.io import load_config
TESTS_PATH = get_tests_path()
conf = load_config(os.path.join(TESTS_PATH, 'test_config.json'))
@ -92,4 +92,4 @@ def test_text2phone():
gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"
lang = "en-us"
ph = text2phone(text, lang)
assert gt == ph, f"\n{ph} \n vs \n{gt}"

tf/README.md (new file)
View File

@ -0,0 +1,12 @@
## Utilities to Convert Models to TensorFlow 2
These are utilities to convert trained Torch models to TensorFlow (>=2.2).
We currently support Tacotron2 with Location-Sensitive Attention.
Be aware that our old Torch models may not work with this module due to changes in the layer naming convention. Therefore, you need to train new models or adapt the old ones to the new naming.
We do not plan to share training scripts for TensorFlow in the near future, but any contribution in that direction is more than welcome.
To see how to use the TF model at inference, check the notebook.
This is an experimental release. If you encounter an error, please open an issue or, even better, send a PR, but you are mostly on your own.
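For reference, a minimal sketch (paths are placeholders; helper signatures follow the inference notebook added in this PR) of loading a converted checkpoint and running inference:

```python
# Sketch only: load a converted TF checkpoint and run inference.
# Paths are placeholders; helpers follow the notebook in this PR.
from TTS.utils.io import load_config
from TTS.utils.text.symbols import symbols, phonemes
from TTS.tf.utils.generic_utils import setup_model, load_checkpoint
from TTS.tf.utils.convert_torch_to_tf_utils import tf_create_dummy_inputs

c = load_config('config.json')        # config used to train the original Torch model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, 0, c)  # 0 speakers

# run the model once so Keras creates the variables, then load the converted weights
input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()
_ = model(input_ids, training=False)
model = load_checkpoint(model, 'tts_tf_checkpoint.pkl')

decoder_frames, mel_frames, attentions, stop_tokens = model(input_ids, training=False)
```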

View File

@ -0,0 +1,196 @@
# %%
import sys
sys.path.append('/home/erogol/Projects')
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# %%
import argparse
import numpy as np
import torch
import tensorflow as tf
from fuzzywuzzy import fuzz
from TTS.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.generic_utils import setup_model, count_parameters
from TTS.utils.io import load_config
from TTS_tf.models.tacotron2 import Tacotron2
from TTS_tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name
from TTS_tf.utils.generic_utils import save_checkpoint
parser = argparse.ArgumentParser()
parser.add_argument(
'--torch_model_path',
type=str,
help='Path to target torch model to be converted to TF.')
parser.add_argument(
'--config_path',
type=str,
help='Path to config file of torch model.')
parser.add_argument(
'--output_path',
type=str,
help='path to save TF model weights.')
args = parser.parse_args()
# load model config
config_path = args.config_path
c = load_config(config_path)
num_speakers = 0
# init torch model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device('cpu'))
state_dict = checkpoint['model']
model.load_state_dict(state_dict)
# init tf model
model_tf = Tacotron2(num_chars=num_chars,
num_speakers=num_speakers,
r=model.decoder.r,
postnet_output_dim=c.audio['num_mels'],
decoder_output_dim=c.audio['num_mels'],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder)
# set the initial layer mapping - these are not captured by the heuristic approach below
# TODO: set layer names so that we can remove this manual matching
common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
var_map = [
('tacotron2/embedding/embeddings:0', 'embedding.weight'),
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0', 'encoder.lstm.weight_ih_l0'),
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0', 'encoder.lstm.weight_hh_l0'),
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0', 'encoder.lstm.weight_ih_l0_reverse'),
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0', 'encoder.lstm.weight_hh_l0_reverse'),
('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0', ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0', ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
('decoder/linear_projection/kernel:0', 'decoder.linear_projection.linear_layer.weight'),
('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight')
]
# %%
# get tf_model graph
input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()
mel_pred = model_tf(input_ids, training=False)
# get tf variables
tf_vars = model_tf.weights
# match variable names with fuzzy logic
torch_var_names = list(state_dict.keys())
tf_var_names = [we.name for we in model_tf.weights]
for tf_name in tf_var_names:
# skip re-mapped layer names
if tf_name in [name[0] for name in var_map]:
continue
tf_name_edited = convert_tf_name(tf_name)
ratios = [fuzz.ratio(torch_name, tf_name_edited) for torch_name in torch_var_names]
max_idx = np.argmax(ratios)
matching_name = torch_var_names[max_idx]
del torch_var_names[max_idx]
var_map.append((tf_name, matching_name))
# %%
# print variable match
from pprint import pprint
pprint(var_map)
pprint(torch_var_names)
# pass weights
tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
# Compare TF and TORCH models
# %%
# check embedding outputs
model.eval()
input_ids = torch.randint(0, 24, (1, 128)).long()
o_t = model.embedding(input_ids)
o_tf = model_tf.embedding(input_ids.detach().numpy())
assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum()
# compare encoder outputs
oo_en = model.encoder.inference(o_t.transpose(1,2))
ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False)
assert compare_torch_tf(oo_en, ooo_en) < 1e-5
# compare decoder.attention_rnn
inp = torch.rand([1, 768])
inp_tf = inp.numpy()
model.decoder._init_states(oo_en, mask=None)
output, cell_state = model.decoder.attention_rnn(inp)
states = model_tf.decoder.build_decoder_initial_states(1,512,128)
output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False)
assert compare_torch_tf(output, output_tf).mean() < 1e-5
# compare decoder.attention
query = output
inputs = torch.rand([1, 128, 512])
query_tf = query.detach().numpy()
inputs_tf = inputs.numpy()
model.decoder.attention.init_states(inputs)
processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs)
context = model.decoder.attention(query, inputs, processes_inputs, None)
model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf))
loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf)
context_tf = model_tf.decoder.attention(query_tf, training=False)
assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5
assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5
assert compare_torch_tf(context, context_tf) < 1e-5
# compare decoder.decoder_rnn
input = torch.rand([1, 1536])
input_tf = input.numpy()
model.decoder._init_states(oo_en, mask=None)
output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
states = model_tf.decoder.build_decoder_initial_states(1,512,128)
output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False)
assert abs(input - input_tf).mean() < 1e-5
assert compare_torch_tf(output, output_tf).mean() < 1e-5
# compare decoder.linear_projection
input = torch.rand([1, 1536])
input_tf = input.numpy()
output = model.decoder.linear_projection(input)
output_tf = model_tf.decoder.linear_projection(input_tf, training=False)
assert compare_torch_tf(output, output_tf) < 1e-5
# compare decoder outputs
model.decoder.max_decoder_steps = 100
model_tf.decoder.set_max_decoder_steps(100)
output, align, stop = model.decoder.inference(oo_en)
states = model_tf.decoder.build_decoder_initial_states(1,512,128)
output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False)
assert compare_torch_tf(output.transpose(1,2), output_tf) < 1e-4
# compare the whole model output
outputs_torch = model.inference(input_ids)
outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy()))
print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean() )
assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5
assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4
# %%
# save tf model
save_checkpoint(model_tf, None, checkpoint['step'], checkpoint['epoch'],
checkpoint['r'], args.output_path)
print(' > Model conversion is successfully completed :).')

tf/layers/common_layers.py (new file)
View File

@ -0,0 +1,258 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.ops import math_ops
# from tensorflow_addons.seq2seq import BahdanauAttention
from TTS.tf.utils.tf_utils import shape_list
class Linear(keras.layers.Layer):
def __init__(self, units, use_bias, **kwargs):
super(Linear, self).__init__(**kwargs)
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
self.activation = keras.layers.ReLU()
def call(self, x, training=None):
"""
shapes:
x: B x T x C
"""
return self.activation(self.linear_layer(x))
class LinearBN(keras.layers.Layer):
def __init__(self, units, use_bias, **kwargs):
super(LinearBN, self).__init__(**kwargs)
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization')
self.activation = keras.layers.ReLU()
def call(self, x, training=None):
"""
shapes:
x: B x T x C
"""
out = self.linear_layer(x)
out = self.batch_normalization(out, training=training)
return self.activation(out)
class Prenet(keras.layers.Layer):
def __init__(self,
prenet_type,
prenet_dropout,
units,
bias,
**kwargs):
super(Prenet, self).__init__(**kwargs)
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
self.linear_layers = []
if prenet_type == "bn":
self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
elif prenet_type == "original":
self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
else:
raise RuntimeError(' [!] Unknown prenet type.')
if prenet_dropout:
self.dropout = keras.layers.Dropout(rate=0.5)
def call(self, x, training=None):
"""
shapes:
x: B x T x C
"""
for linear in self.linear_layers:
if self.prenet_dropout:
x = self.dropout(linear(x), training=training)
else:
x = linear(x)
return x
def _sigmoid_norm(score):
attn_weights = tf.nn.sigmoid(score)
attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True)
return attn_weights
class Attention(keras.layers.Layer):
"""TODO: implement forward_attention"""
"""TODO: location sensitive attention"""
"""TODO: implement attention windowing """
def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters,
loc_attn_kernel_size, use_windowing, norm, use_forward_attn,
use_trans_agent, use_forward_attn_mask, **kwargs):
super(Attention, self).__init__(**kwargs)
self.use_loc_attn = use_loc_attn
self.loc_attn_n_filters = loc_attn_n_filters
self.loc_attn_kernel_size = loc_attn_kernel_size
self.use_windowing = use_windowing
self.norm = norm
self.use_forward_attn = use_forward_attn
self.use_trans_agent = use_trans_agent
self.use_forward_attn_mask = use_forward_attn_mask
self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer')
self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer')
self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer')
if use_loc_attn:
self.location_conv1d = keras.layers.Conv1D(
filters=loc_attn_n_filters,
kernel_size=loc_attn_kernel_size,
padding='same',
use_bias=False,
name='location_layer/location_conv1d')
self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense')
if norm == 'softmax':
self.norm_func = tf.nn.softmax
elif norm == 'sigmoid':
self.norm_func = _sigmoid_norm
else:
raise ValueError("Unknown value for attention norm type")
def init_states(self, batch_size, value_length):
states = ()
if self.use_loc_attn:
attention_cum = tf.zeros([batch_size, value_length])
attention_old = tf.zeros([batch_size, value_length])
states = (attention_cum, attention_old)
return states
def process_values(self, values):
""" cache values for decoder iterations """
self.processed_values = self.inputs_layer(values)
self.values = values
def get_loc_attn(self, query, states):
""" compute location attention, query layer and
unnorm. attention weights"""
attention_cum, attention_old = states
attn_cat = tf.stack([attention_old, attention_cum],
axis=2)
processed_query = self.query_layer(tf.expand_dims(query, 1))
processed_attn = self.location_dense(self.location_conv1d(attn_cat))
score = self.v(
tf.nn.tanh(self.processed_values + processed_query +
processed_attn))
score = tf.squeeze(score, axis=2)
return score, processed_query
def get_attn(self, query):
""" compute query layer and unnormalized attention weights """
processed_query = self.query_layer(tf.expand_dims(query, 1))
score = self.v(tf.nn.tanh(self.processed_values + processed_query))
score = tf.squeeze(score, axis=2)
return score, processed_query
def apply_score_masking(self, score, mask):
""" ignore sequence paddings """
padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
# Bias so padding positions do not contribute to attention distribution.
score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
return score
def call(self, query, states):
"""
shapes:
query: B x D
"""
if self.use_loc_attn:
score, processed_query = self.get_loc_attn(query, states)
else:
score, processed_query = self.get_attn(query)
# TODO: masking
# if mask is not None:
# self.apply_score_masking(score, mask)
# attn_weights shape == (batch_size, max_length, 1)
attn_weights = self.norm_func(score)
# update attention states
if self.use_loc_attn:
states = (states[0] + attn_weights, attn_weights)
else:
states = ()
# context_vector shape after sum == (batch_size, hidden_size)
context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
context_vector = tf.squeeze(context_vector, axis=1)
return context_vector, attn_weights, states
# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b):
# dtype = processed_query.dtype
# num_units = keys.shape[-1].value or array_ops.shape(keys)[-1]
# return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2])
# class LocationSensitiveAttention(BahdanauAttention):
# def __init__(self,
# units,
# memory=None,
# memory_sequence_length=None,
# normalize=False,
# probability_fn="softmax",
# kernel_initializer="glorot_uniform",
# dtype=None,
# name="LocationSensitiveAttention",
# location_attention_filters=32,
# location_attention_kernel_size=31):
# super(LocationSensitiveAttention,
# self).__init__(units=units,
# memory=memory,
# memory_sequence_length=memory_sequence_length,
# normalize=normalize,
# probability_fn='softmax', ## parent module default
# kernel_initializer=kernel_initializer,
# dtype=dtype,
# name=name)
# if probability_fn == 'sigmoid':
# self.probability_fn = lambda score, _: self._sigmoid_normalization(score)
# self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False)
# self.location_dense = keras.layers.Dense(units, use_bias=False)
# # self.v = keras.layers.Dense(1, use_bias=True)
# def _location_sensitive_score(self, processed_query, keys, processed_loc):
# processed_query = tf.expand_dims(processed_query, 1)
# return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2])
# def _location_sensitive(self, alignment_cum, alignment_old):
# alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2)
# return self.location_dense(self.location_conv(alignment_cat))
# def _sigmoid_normalization(self, score):
# return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True)
# # def _apply_masking(self, score, mask):
# # padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
# # # Bias so padding positions do not contribute to attention distribution.
# # score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
# # return score
# def _calculate_attention(self, query, state):
# alignment_cum, alignment_old = state[:2]
# processed_query = self.query_layer(
# query) if self.query_layer else query
# processed_loc = self._location_sensitive(alignment_cum, alignment_old)
# score = self._location_sensitive_score(
# processed_query,
# self.keys,
# processed_loc)
# alignment = self.probability_fn(score, state)
# alignment_cum = alignment_cum + alignment
# state[0] = alignment_cum
# state[1] = alignment
# return alignment, state
# def compute_context(self, alignments):
# expanded_alignments = tf.expand_dims(alignments, 1)
# context = tf.matmul(expanded_alignments, self.values)
# context = tf.squeeze(context, [1])
# return context
# # def call(self, query, state):
# # alignment, next_state = self._calculate_attention(query, state)
# # return alignment, next_state

tf/layers/tacotron2.py (new file)
View File

@ -0,0 +1,231 @@
import tensorflow as tf
from tensorflow import keras
from TTS.tf.utils.tf_utils import shape_list
from TTS.tf.layers.common_layers import Prenet, Attention
# from tensorflow_addons.seq2seq import AttentionWrapper
class ConvBNBlock(keras.layers.Layer):
def __init__(self, filters, kernel_size, activation, **kwargs):
super(ConvBNBlock, self).__init__(**kwargs)
self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding='same', name='convolution1d')
self.batch_normalization = keras.layers.BatchNormalization(axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization')
self.dropout = keras.layers.Dropout(rate=0.5, name='dropout')
self.activation = keras.layers.Activation(activation, name='activation')
def call(self, x, training=None):
o = self.convolution1d(x)
o = self.batch_normalization(o, training=training)
o = self.activation(o)
o = self.dropout(o, training=training)
return o
class Postnet(keras.layers.Layer):
def __init__(self, output_filters, num_convs, **kwargs):
super(Postnet, self).__init__(**kwargs)
self.convolutions = []
self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0'))
for idx in range(1, num_convs - 1):
self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}'))
self.convolutions.append(ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{idx+1}'))
def call(self, x, training=None):
o = x
for layer in self.convolutions:
o = layer(o, training=training)
return o
class Encoder(keras.layers.Layer):
def __init__(self, output_input_dim, **kwargs):
super(Encoder, self).__init__(**kwargs)
self.convolutions = []
for idx in range(3):
self.convolutions.append(ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}'))
self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name='lstm')
def call(self, x, training=None):
o = x
for layer in self.convolutions:
o = layer(o, training=training)
o = self.lstm(o)
return o
class Decoder(keras.layers.Layer):
def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type,
prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask,
use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, **kwargs):
super(Decoder, self).__init__(**kwargs)
self.frame_dim = frame_dim
self.r_init = tf.constant(r, dtype=tf.int32)
self.r = tf.constant(r, dtype=tf.int32)
self.separate_stopnet = separate_stopnet
self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
self.stop_thresh = tf.constant(0.5, dtype=tf.float32)
# model dimensions
self.query_dim = 1024
self.decoder_rnn_dim = 1024
self.prenet_dim = 256
self.attn_dim = 128
self.p_attention_dropout = 0.1
self.p_decoder_dropout = 0.1
self.prenet = Prenet(prenet_type,
prenet_dropout,
[self.prenet_dim, self.prenet_dim],
bias=False,
name='prenet')
self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'{self.name}/attention_rnn', )
self.attention_rnn_dropout = keras.layers.Dropout(0.5)
# TODO: implement other attn options
self.attention = Attention(attn_dim=self.attn_dim,
use_loc_attn=True,
loc_attn_n_filters=32,
loc_attn_kernel_size=31,
use_windowing=False,
norm=attn_norm,
use_forward_attn=use_forward_attn,
use_trans_agent=use_trans_agent,
use_forward_attn_mask=use_forward_attn_mask,
name='attention')
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'{self.name}/decoder_rnn')
self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'{self.name}/linear_projection/linear_layer')
self.stopnet = keras.layers.Dense(1, name=f'{self.name}/stopnet/linear_layer')
def set_max_decoder_steps(self, new_max_steps):
self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32)
def set_r(self, new_r):
self.r = tf.constant(new_r, dtype=tf.int32)
def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
zero_frame = tf.zeros([batch_size, self.frame_dim])
zero_context = tf.zeros([batch_size, memory_dim])
attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
decoder_rnn_state = self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32)
attention_states = self.attention.init_states(batch_size, memory_length)
return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states
def step(self, prenet_next, states,
memory_seq_length=None, training=None):
_, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states
attention_rnn_input = tf.concat([prenet_next, context_next], -1)
attention_rnn_output, attention_rnn_state = \
self.attention_rnn(attention_rnn_input,
attention_rnn_state, training=training)
attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training)
context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training)
decoder_rnn_input = tf.concat([attention_rnn_output, context], -1)
decoder_rnn_output, decoder_rnn_state = \
self.decoder_rnn(decoder_rnn_input, decoder_rnn_state, training=training)
decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training)
linear_projection_input = tf.concat([decoder_rnn_output, context], -1)
output_frame = self.linear_projection(linear_projection_input, training=training)
stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1)
stopnet_output = self.stopnet(stopnet_input, training=training)
output_frame = output_frame[:, :self.r * self.frame_dim]
states = (output_frame[:, self.frame_dim * (self.r - 1):], context, attention_rnn_state, decoder_rnn_state, attention_states)
return output_frame, stopnet_output, states, attention
def decode(self, memory, states, frames, memory_seq_length=None):
B, T, D = shape_list(memory)
num_iter = shape_list(frames)[1] // self.r
# init states
frame_zero = tf.expand_dims(states[0], 1)
frames = tf.concat([frame_zero, frames], axis=1)
outputs = tf.TensorArray(dtype=tf.float32, size=num_iter)
attentions = tf.TensorArray(dtype=tf.float32, size=num_iter)
stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter)
# pre-computes
self.attention.process_values(memory)
prenet_output = self.prenet(frames, training=True)
step_count = tf.constant(0, dtype=tf.int32)
def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions):
prenet_next = prenet_output[:, step]
output, stop_token, states, attention = self.step(prenet_next,
states,
memory_seq_length)
outputs = outputs.write(step, output)
attentions = attentions.write(step, attention)
stop_tokens = stop_tokens.write(step, stop_token)
return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions
_, memory, _, states, outputs, stop_tokens, attentions = \
tf.while_loop(lambda *arg: True,
_body,
loop_vars=(step_count, memory, prenet_output, states, outputs,
stop_tokens, attentions),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=num_iter)
outputs = outputs.stack()
attentions = attentions.stack()
stop_tokens = stop_tokens.stack()
outputs = tf.transpose(outputs, [1, 0, 2])
attentions = tf.transpose(attentions, [1, 0, 2])
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
stop_tokens = tf.squeeze(stop_tokens, axis=2)
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
return outputs, stop_tokens, attentions
def decode_inference(self, memory, states):
B, T, D = shape_list(memory)
# init states
outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
# pre-computes
self.attention.process_values(memory)
# iter vars
stop_flag = tf.constant(False, dtype=tf.bool)
step_count = tf.constant(0, dtype=tf.int32)
def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag):
frame_next = states[0]
prenet_next = self.prenet(frame_next, training=False)
output, stop_token, states, attention = self.step(prenet_next,
states,
None,
training=False)
stop_token = tf.math.sigmoid(stop_token)
outputs = outputs.write(step, output)
attentions = attentions.write(step, attention)
stop_tokens = stop_tokens.write(step, stop_token)
stop_flag = tf.greater(stop_token, self.stop_thresh)
stop_flag = tf.reduce_all(stop_flag)
return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag
cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
_, memory, states, outputs, stop_tokens, attentions, stop_flag = \
tf.while_loop(cond,
_body,
loop_vars=(step_count, memory, states, outputs,
stop_tokens, attentions, stop_flag),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=self.max_decoder_steps)
outputs = outputs.stack()
attentions = attentions.stack()
stop_tokens = stop_tokens.stack()
outputs = tf.transpose(outputs, [1, 0, 2])
attentions = tf.transpose(attentions, [1, 0, 2])
stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
stop_tokens = tf.squeeze(stop_tokens, axis=2)
outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
return outputs, stop_tokens, attentions
def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
if training:
return self.decode(memory, states, frames, memory_seq_length)
return self.decode_inference(memory, states)
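For orientation, a sketch (constructor arguments and shapes assumed; most attention options are unused for now) of driving this decoder in inference mode on random encoder outputs, mirroring the conversion script in this PR:

```python
# Sketch only: run the TF decoder at inference on a dummy encoder output.
import tensorflow as tf
from TTS.tf.layers.tacotron2 import Decoder

decoder = Decoder(frame_dim=80, r=1, attn_type='original', use_attn_win=False,
                  attn_norm='softmax', prenet_type='original', prenet_dropout=True,
                  use_forward_attn=False, use_trans_agent=False,
                  use_forward_attn_mask=False, use_location_attn=True, attn_K=4,
                  separate_stopnet=True, speaker_emb_dim=256, name='decoder')
decoder.set_max_decoder_steps(50)                            # keep the untrained loop short

memory = tf.random.uniform([1, 128, 512])                    # B x T_in x encoder_dim
states = decoder.build_decoder_initial_states(1, 512, 128)   # (frame, context, attn_rnn, dec_rnn, attn) states
outputs, stop_tokens, attentions = decoder(memory, states, training=False)
# outputs: [1, T_decoder, frame_dim], attentions: [1, T_decoder, 128]
```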

tf/models/tacotron2.py (new file)
View File

@ -0,0 +1,72 @@
import tensorflow as tf
from tensorflow import keras
from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
from TTS.tf.utils.tf_utils import shape_list
class Tacotron2(keras.models.Model):
def __init__(self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type='original',
attn_win=False,
attn_norm="softmax",
attn_K=4,
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
separate_stopnet=True,
bidirectional_decoder=False):
super(Tacotron2, self).__init__()
self.r = r
self.decoder_output_dim = decoder_output_dim
self.postnet_output_dim = postnet_output_dim
self.bidirectional_decoder = bidirectional_decoder
self.num_speakers = num_speakers
self.speaker_embed_dim = 256
self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
self.encoder = Encoder(512, name='encoder')
# TODO: most of the decoder args have no use at the moment
self.decoder = Decoder(decoder_output_dim, r, attn_type=attn_type, use_attn_win=attn_win, attn_norm=attn_norm, prenet_type=prenet_type,
prenet_dropout=prenet_dropout, use_forward_attn=forward_attn, use_trans_agent=trans_agent, use_forward_attn_mask=forward_attn_mask,
use_location_attn=location_attn, attn_K=attn_K, separate_stopnet=separate_stopnet, speaker_emb_dim=self.speaker_embed_dim)
self.postnet = Postnet(postnet_output_dim, 5, name='postnet')
def call(self, characters, text_lengths=None, frames=None, training=None):
if training == True:
return self.training(characters, text_lengths, frames)
else:
return self.inference(characters)
def training(self, characters, text_lengths, frames):
B, T = shape_list(characters)
embedding_vectors = self.embedding(characters, training=True)
encoder_output = self.encoder(embedding_vectors, training=True)
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, frames, text_lengths, training=True)
postnet_frames = self.postnet(decoder_frames, training=True)
output_frames = decoder_frames + postnet_frames
return decoder_frames, output_frames, attentions, stop_tokens
def inference(self, characters):
B, T = shape_list(characters)
embedding_vectors = self.embedding(characters, training=False)
encoder_output = self.encoder(embedding_vectors, training=False)
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
postnet_frames = self.postnet(decoder_frames, training=False)
output_frames = decoder_frames + postnet_frames
print(output_frames.shape)
return decoder_frames, output_frames, attentions, stop_tokens
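A brief usage sketch mirroring the TF test added in this PR (sizes are arbitrary; the target length must be divisible by `r` for the teacher-forced pass):

```python
# Sketch only: teacher-forced and free-running passes of the TF Tacotron2.
import tensorflow as tf
from TTS.tf.models.tacotron2 import Tacotron2

model = Tacotron2(num_chars=24, r=7, num_speakers=5)
characters = tf.random.uniform([8, 128], maxval=24, dtype=tf.int32)
text_lengths = tf.fill([8], 128)
mel_specs = tf.random.uniform([8, 35, 80])   # 35 frames = 5 decoder steps at r=7

# teacher-forced training pass
decoder_frames, output_frames, attentions, stop_tokens = model(
    characters, text_lengths, mel_specs, training=True)

# free-running inference pass
model.decoder.set_max_decoder_steps(100)     # keep the untrained inference loop short
decoder_frames, output_frames, attentions, stop_tokens = model(characters, training=False)
```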

View File

@ -0,0 +1,708 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"This is to test TTS models with benchmark sentences for speech synthesis.\n",
"\n",
"Before running this script please DON'T FORGET: \n",
"- to set file paths.\n",
"- to download related model files.\n",
"- download or clone related repos, linked below.\n",
"- setup the repositories. ```python setup.py install```\n",
"- to checkout right commit versions (given next to the model in the models page).\n",
"- to set the file paths below.\n",
"\n",
"Repositories:\n",
"- TTS: https://github.com/mozilla/TTS"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false",
"scrolled": true
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import os\n",
"\n",
"# you may need to change this depending on your system\n",
"os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
"\n",
"import sys\n",
"import io\n",
"import torch \n",
"import tensorflow as tf\n",
"print(tf.config.list_physical_devices('GPU'))\n",
"\n",
"import time\n",
"import json\n",
"import yaml\n",
"import numpy as np\n",
"from collections import OrderedDict\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams[\"figure.figsize\"] = (16,5)\n",
"\n",
"import librosa\n",
"import librosa.display\n",
"\n",
"from TTS.tf.models.tacotron2 import Tacotron2\n",
"from TTS.tf.utils.generic_utils import setup_model, load_checkpoint\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.io import load_config\n",
"from TTS.utils.synthesis import synthesis\n",
"from TTS.utils.visual import visualize\n",
"\n",
"import IPython\n",
"from IPython.display import Audio\n",
"\n",
"%matplotlib agg"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n",
" t_1 = time.time()\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens, inputs = synthesis(model, text, CONFIG, use_cuda, ap, None, None, False, CONFIG.enable_eos_bos_chars, use_gl, backend=BACKEND)\n",
" if CONFIG.model == \"Tacotron\" and not use_gl:\n",
" # coorect the normalization differences b/w TTS and the Vocoder.\n",
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
" print(mel_postnet_spec.shape)\n",
" print(\"max- \", mel_postnet_spec.max(), \" -- min- \", mel_postnet_spec.min())\n",
" if not use_gl:\n",
" waveform = vocoder_model.inference(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0))\n",
" mel_postnet_spec = ap._denormalize(mel_postnet_spec.T).T\n",
" if use_cuda and not use_gl:\n",
" waveform = waveform.cpu()\n",
" waveform = waveform.numpy()\n",
" waveform = waveform.squeeze()\n",
" rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n",
" print(waveform.shape)\n",
" print(\" > Run-time: {}\".format(time.time() - t_1))\n",
" print(\" > Real-time factor: {}\".format(rtf))\n",
" if figures: \n",
" visualize(alignment, mel_postnet_spec, stop_tokens, text, ap.hop_length, CONFIG, ap._denormalize(mel_spec.T).T) \n",
" IPython.display.display(Audio(waveform, rate=CONFIG.audio['sample_rate'], normalize=True)) \n",
" os.makedirs(OUT_FOLDER, exist_ok=True)\n",
" file_name = text.replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n",
" out_path = os.path.join(OUT_FOLDER, file_name)\n",
" ap.save_wav(waveform, out_path)\n",
" return alignment, mel_postnet_spec, stop_tokens, waveform"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"# Set constants\n",
"ROOT_PATH = '../tf_model/'\n",
"MODEL_PATH = ROOT_PATH + '/tts_tf_checkpoint_360000.pkl'\n",
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
"OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n",
"CONFIG = load_config(CONFIG_PATH)\n",
"# Run FLAGs\n",
"use_cuda = True\n",
"# Set the vocoder\n",
"use_gl = True # use GL if True\n",
"BACKEND = 'tf'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false",
"scrolled": true
},
"outputs": [],
"source": [
"from TTS.utils.text.symbols import symbols, phonemes, make_symbols\n",
"from TTS.tf.utils.convert_torch_to_tf_utils import tf_create_dummy_inputs\n",
"c = CONFIG\n",
"num_speakers = 0\n",
"r = 1\n",
"num_chars = len(phonemes) if c.use_phonemes else len(symbols)\n",
"model = setup_model(num_chars, num_speakers, c)\n",
"\n",
"# before loading weights you need to run the model once to generate the variables\n",
"input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()\n",
"mel_pred = model(input_ids, training=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false",
"scrolled": true
},
"outputs": [],
"source": [
"model = load_checkpoint(model, MODEL_PATH)\n",
"# model = tf.function(model, experimental_relax_shapes=True)\n",
"ap = AudioProcessor(**CONFIG.audio) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"# wrapper class to use tf.function\n",
"class ModelInference(tf.keras.Model):\n",
" def __init__(self, model):\n",
" super(ModelInference, self).__init__()\n",
" self.model = model\n",
" \n",
" @tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32)])\n",
" def call(self, characters):\n",
" return self.model(characters, training=False)\n",
" \n",
"model = ModelInference(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"# LOAD WAVERNN\n",
"if use_gl == False:\n",
" from parallel_wavegan.models import ParallelWaveGANGenerator, MelGANGenerator\n",
" \n",
" vocoder_model = MelGANGenerator(**VOCODER_CONFIG[\"generator_params\"])\n",
" vocoder_model.load_state_dict(torch.load(VOCODER_MODEL_PATH, map_location=\"cpu\")[\"model\"][\"generator\"])\n",
" vocoder_model.remove_weight_norm()\n",
" ap_vocoder = AudioProcessor(**VOCODER_CONFIG['audio']) \n",
" if use_cuda:\n",
" vocoder_model.cuda()\n",
" vocoder_model.eval();\n",
" print(count_parameters(vocoder_model))"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Comparision with https://mycroft.ai/blog/available-voices/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Bill got in the habit of asking himself “Is that thought true?” and if he wasnt absolutely certain it was, he just let it go.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### https://espnet.github.io/icassp2020-tts/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"The Commission also recommends\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"As a result of these studies, the planning document submitted by the Secretary of the Treasury to the Bureau of the Budget on August thirty-one.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"The FBI now transmits information on all defectors, a category which would, of course, have included Oswald.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"they seem unduly restrictive in continuing to require some manifestation of animus against a Government official.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"and each agency given clear understanding of the assistance which the Secret Service expects.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Other examples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Be a voice, not an echo.\" # 'echo' is not in training set. \n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"The human voice is the most perfect instrument of all.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"This cake is great. It's so delicious and moist.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Comparison with https://keithito.github.io/audio-samples/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Heres a way to measure the acute emotional intelligence that has never gone out of style.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"The buses aren't the problem, they actually provide a solution.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Comparison with https://google.github.io/tacotron/publications/tacotron/index.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \" He has read the whole thing.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"He reads books.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Thisss isrealy awhsome.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"This is your internet browser, Firefox.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"This is your internet browser Firefox.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"The quick brown fox jumps over the lazy dog.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Does the quick brown fox jump over the lazy dog?\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Eren, how are you?\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"### Hard Sentences"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Encouraged, he started with a minute a day.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . \"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"If he decided to watch TV he really watched it.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"sentence = \"Often we try to bring about change through sheer effort and we put all of our energy into a new initiative .\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"# for twb dataset\n",
"sentence = \"In our preparation for Easter, God in his providence offers us each year the season of Lent as a sacramental sign of our conversion.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"wavs = []\n",
"model.eval()\n",
"model.decoder.prenet.eval()\n",
"model.decoder.max_decoder_steps = 2000\n",
"# model.decoder.prenet.train()\n",
"speaker_id = None\n",
"sentence = '''This is App Store Optimization report.\n",
"The first tab on the report is App Details. App details report is updated weekly and Datetime column shows the latest report update date. The widget displays the app icon, respective app version, visual assets on the store, app description, latest app update date on the Appstore/Google PlayStore and whats new section.\n",
"In App Details tab, you can see not only your app but all Delivery Hero apps since we think it can be inspiring to see the other apps, their description and screenshots. \n",
"Product name is the actual app name on the AppStore or Google Play Store.\n",
"Screenshot URLs column display the actual screenshots on the store for the current version. No resizing is done. If you click on the screenshot, you can see it in full-size.\n",
"Current release date show the latest app update date when the query is run. Here we see that Appetito24 Android is updated to app version 4.6.3.2 on 28th of March.\n",
"If the description is too long, clarisights is not able to display the full description; however, if you select description and current_release_date cells to copy and paste it to a text editor, you'll see the full description.\n",
"If you scroll down in the widget, you can see the older app versions for the same apps. Or you can filter Datetime to see a specific timeframe and the apps Store presence back then.\n",
"You can also filter for a specific app using Product Name.\n",
"If the description is too long, clarisights is not able to display the full description; however, if you select description and current_release_date cells to copy and paste it to a text editor, you'll see the full description.\n",
"'''\n",
"\n",
"for s in sentence.split('\\n'):\n",
" print(s)\n",
" align, spec, stop_tokens, wav = tts(model, s, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)\n",
" wavs = np.concatenate([wavs, np.zeros(int(ap.sample_rate * 0.5)), wav])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

2
tf/requirements Normal file
View File

@ -0,0 +1,2 @@
fuzzywuzzy
tensorflow>=2.2.0

View File

@ -0,0 +1,83 @@
import numpy as np
import torch
import re
import tensorflow as tf
import tensorflow.keras.backend as K
def tf_create_dummy_inputs():
""" Create dummy inputs for TF Tacotron2 model """
batch_size = 4
max_input_length = 32
max_mel_length = 128
pad = 1
n_chars = 24
input_ids = tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32)
input_lengths = np.random.randint(0, high=max_input_length+1 + pad, size=[batch_size])
input_lengths[-1] = max_input_length
input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32)
mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80])
mel_lengths = np.random.randint(0, high=max_mel_length+1 + pad, size=[batch_size])
mel_lengths[-1] = max_mel_length
mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32)
return input_ids, input_lengths, mel_outputs, mel_lengths
def compare_torch_tf(torch_tensor, tf_tensor):
""" Compute the average absolute difference b/w torch and tf tensors """
return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean()
def convert_tf_name(tf_name):
""" Convert certain patterns in TF layer names to Torch patterns """
tf_name_tmp = tf_name
tf_name_tmp = tf_name_tmp.replace(':0', '')
tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0')
tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1')
tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh')
tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight')
tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight')
tf_name_tmp = tf_name_tmp.replace('/beta', '/bias')
tf_name_tmp = tf_name_tmp.replace('/', '.')
return tf_name_tmp
def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
""" Transfer weigths from torch state_dict to TF variables """
print(" > Passing weights from Torch to TF ...")
for tf_var in tf_vars:
torch_var_name = var_map_dict[tf_var.name]
print(f' | > {tf_var.name} <-- {torch_var_name}')
# if tuple, it is a bias variable
if type(torch_var_name) is not tuple:
torch_layer_name = '.'.join(torch_var_name.split('.')[-2:])
torch_weight = state_dict[torch_var_name]
if 'convolution1d/kernel' in tf_var.name or 'conv1d/kernel' in tf_var.name:
# out_dim, in_dim, filter -> filter, in_dim, out_dim
numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy()
elif 'lstm_cell' in tf_var.name and 'kernel' in tf_var.name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
        # if the variable belongs to a bidirectional LSTM and is a bias vector,
        # two matching torch bias vectors need to be pre-defined in the name map
elif '_lstm/lstm_cell_' in tf_var.name and 'bias' in tf_var.name:
bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name]
assert len(bias_vectors) == 2
numpy_weight = bias_vectors[0] + bias_vectors[1]
elif 'rnn' in tf_var.name and 'kernel' in tf_var.name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
elif 'rnn' in tf_var.name and 'bias' in tf_var.name:
bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key]
assert len(bias_vectors) == 2
numpy_weight = bias_vectors[0] + bias_vectors[1]
elif 'linear_layer' in torch_layer_name and 'weight' in torch_var_name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
else:
numpy_weight = torch_weight.detach().cpu().numpy()
        assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes do not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
tf.keras.backend.set_value(tf_var, numpy_weight)
def load_tf_vars(model_tf, tf_vars):
for tf_var in tf_vars:
model_tf.get_layer(tf_var.name).set_weights(tf_var)
return model_tf
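
For readers tracing the conversion flow, here is a minimal sketch of how the helpers above chain together: build the TF model, run it once on dummy inputs so Keras creates its variables, map TF variable names to torch state_dict keys, then copy the weights. The checkpoint path, the model_tf instance and the one-line name map are placeholders rather than part of this PR; a real converter needs a more careful map (for example the tuple entries for bidirectional-LSTM biases) on top of convert_tf_name.

# Sketch only -- model_tf, the checkpoint path and the naive name map are assumptions.
import torch
from TTS.tf.utils.convert_torch_to_tf_utils import (
    tf_create_dummy_inputs, convert_tf_name, transfer_weights_torch_to_tf)

# model_tf is assumed to be a TF Tacotron2 built via setup_model (see the notebook above)
state_dict = torch.load('tts_torch_checkpoint.pth.tar', map_location='cpu')['model']

# run the TF model once so its variables exist (same trick as in the notebook above)
input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()
_ = model_tf(input_ids, training=False)

# placeholder map: TF variable name -> torch state_dict key
var_map_dict = {var.name: convert_tf_name(var.name) for var in model_tf.weights}

transfer_weights_torch_to_tf(model_tf.weights, var_map_dict, state_dict)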

105
tf/utils/generic_utils.py Normal file
View File

@ -0,0 +1,105 @@
import os
import re
import glob
import shutil
import datetime
import json
import subprocess
import importlib
import pickle
import numpy as np
from collections import OrderedDict, Counter
import tensorflow as tf
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
checkpoint_path = 'tts_tf_checkpoint_{}.pkl'.format(current_step)
checkpoint_path = os.path.join(output_folder, checkpoint_path)
state = {
'model': model.weights,
'optimizer': optimizer,
'step': current_step,
'epoch': epoch,
'date': datetime.date.today().strftime("%B %d, %Y"),
'r': r
}
state.update(kwargs)
pickle.dump(state, open(checkpoint_path, 'wb'))
def load_checkpoint(model, checkpoint_path):
checkpoint = pickle.load(open(checkpoint_path, 'rb'))
chkp_var_dict = dict([(var.name, var.numpy()) for var in checkpoint['model']])
tf_vars = model.weights
for tf_var in tf_vars:
layer_name = tf_var.name
chkp_var_value = chkp_var_dict[layer_name]
tf.keras.backend.set_value(tf_var, chkp_var_value)
if 'r' in checkpoint.keys():
model.decoder.set_r(checkpoint['r'])
return model
def sequence_mask(sequence_length, max_len=None):
    """ Create a boolean mask (B x T_max) from a vector of sequence lengths. """
    if max_len is None:
        max_len = tf.reduce_max(sequence_length)
    # B x T_max, True for valid (non-padded) steps
    return tf.sequence_mask(sequence_length, maxlen=max_len, dtype=tf.bool)
# @tf.custom_gradient
def check_gradient(x, grad_clip):
    """ Clip the gradient by norm and also return its pre-clip norm. """
    x_normed = tf.clip_by_norm(x, grad_clip)
    grad_norm = tf.norm(x)
    return x_normed, grad_norm
def count_parameters(model, c):
try:
return model.count_params()
    except Exception:  # model is not built yet -> run a dummy forward pass first
input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype('int32'))
input_lengths = np.random.randint(100, 129, (8, ))
input_lengths[-1] = 128
input_lengths = tf.convert_to_tensor(input_lengths.astype('int32'))
mel_spec = np.random.rand(8, 2 * c.r,
c.audio['num_mels']).astype('float32')
mel_spec = tf.convert_to_tensor(mel_spec)
speaker_ids = np.random.randint(
0, 5, (8, )) if c.use_speaker_embedding else None
_ = model(input_dummy, input_lengths, mel_spec)
return model.count_params()
def setup_model(num_chars, num_speakers, c):
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module('TTS.tf.models.' + c.model.lower())
MyModel = getattr(MyModel, c.model)
    if c.model.lower() == "tacotron":
        raise NotImplementedError(' [!] Tacotron model is not ready.')
elif c.model.lower() == "tacotron2":
model = MyModel(num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio['num_mels'],
decoder_output_dim=c.audio['num_mels'],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder)
return model

8
tf/utils/tf_utils.py Normal file
View File

@ -0,0 +1,8 @@
import tensorflow as tf
def shape_list(x):
"""Deal with dynamic shape in tensorflow cleanly."""
static = x.shape.as_list()
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
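
As a quick illustration (not part of the PR), shape_list lets graph-mode code mix static and dynamic dimensions: sizes known at trace time come back as plain Python ints, unknown ones as tensors.

import tensorflow as tf
from TTS.tf.utils.tf_utils import shape_list

@tf.function(input_signature=[tf.TensorSpec(shape=(None, None, 80), dtype=tf.float32)])
def frames_in_batch(mel):
    # batch and time are dynamic tensors here; num_mels is the static int 80
    batch, time_steps, num_mels = shape_list(mel)
    return batch * time_steps * num_mels

print(frames_in_batch(tf.zeros([2, 10, 80])))  # tf.Tensor(1600, shape=(), dtype=int32)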

View File

@ -99,7 +99,7 @@ def sequence_mask(sequence_length, max_len=None):
seq_range = torch.arange(0, max_len).long()
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
if sequence_length.is_cuda:
seq_range_expand = seq_range_expand.cuda()
seq_range_expand = seq_range_expand.to(sequence_length.device)
seq_length_expand = (
sequence_length.unsqueeze(1).expand_as(seq_range_expand))
# B x T_max

View File

@ -1,3 +1,7 @@
import pkg_resources
installed = {pkg.key for pkg in pkg_resources.working_set}
if 'tensorflow' in installed or 'tensorflow-gpu' in installed:
import tensorflow as tf
import torch
import numpy as np
from .text import text_to_sequence, phoneme_to_sequence
@ -14,23 +18,32 @@ def text_to_seqvec(text, CONFIG, use_cuda):
dtype=np.int32)
else:
seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32)
# torch tensor
chars_var = torch.from_numpy(seq).unsqueeze(0)
if use_cuda:
chars_var = chars_var.cuda()
return chars_var.long()
return seq
def numpy_to_torch(np_array, dtype, cuda=False):
if np_array is None:
return None
tensor = torch.as_tensor(np_array, dtype=dtype)
if cuda:
return tensor.cuda()
return tensor
def numpy_to_tf(np_array, dtype):
if np_array is None:
return None
tensor = tf.convert_to_tensor(np_array, dtype=dtype)
return tensor
def compute_style_mel(style_wav, ap, use_cuda):
print(style_wav)
style_mel = torch.FloatTensor(ap.melspectrogram(
ap.load_wav(style_wav))).unsqueeze(0)
if use_cuda:
return style_mel.cuda()
    style_mel = np.expand_dims(
        ap.melspectrogram(ap.load_wav(style_wav)), 0)
return style_mel
def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
if CONFIG.use_gst:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, style_mel=style_mel, speaker_ids=speaker_id)
@ -44,11 +57,31 @@ def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None)
return decoder_output, postnet_output, alignments, stop_tokens
def parse_outputs(postnet_output, decoder_output, alignments):
def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
if CONFIG.use_gst:
        raise NotImplementedError(' [!] GST inference not implemented for TF')
    if truncated:
        raise NotImplementedError(' [!] Truncated inference not implemented for TF')
# TODO: handle multispeaker case
decoder_output, postnet_output, alignments, stop_tokens = model(
inputs)
return decoder_output, postnet_output, alignments, stop_tokens
def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
postnet_output = postnet_output[0].data.cpu().numpy()
decoder_output = decoder_output[0].data.cpu().numpy()
alignment = alignments[0].cpu().data.numpy()
return postnet_output, decoder_output, alignment
stop_tokens = stop_tokens[0].cpu().numpy()
return postnet_output, decoder_output, alignment, stop_tokens
def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
postnet_output = postnet_output[0].numpy()
decoder_output = decoder_output[0].numpy()
alignment = alignments[0].numpy()
stop_tokens = stop_tokens[0].numpy()
return postnet_output, decoder_output, alignment, stop_tokens
def trim_silence(wav, ap):
@ -98,7 +131,8 @@ def synthesis(model,
truncated=False,
enable_eos_bos_chars=False, #pylint: disable=unused-argument
use_griffin_lim=False,
do_trim_silence=False):
do_trim_silence=False,
backend='torch'):
"""Synthesize voice for the given text.
Args:
@ -114,6 +148,7 @@ def synthesis(model,
for continuous inference at long texts.
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
do_trim_silence (bool): trim silence after synthesis.
        backend (str): inference backend, either 'tf' or 'torch'.
"""
# GST processing
style_mel = None
@ -121,15 +156,29 @@ def synthesis(model,
style_mel = compute_style_mel(style_wav, ap, use_cuda)
# preprocess the given text
inputs = text_to_seqvec(text, CONFIG, use_cuda)
speaker_id = id_to_torch(speaker_id)
if speaker_id is not None and use_cuda:
speaker_id = speaker_id.cuda()
# pass tensors to backend
if backend == 'torch':
speaker_id = id_to_torch(speaker_id)
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
inputs = inputs.unsqueeze(0)
else:
# TODO: handle speaker id for tf model
style_mel = numpy_to_tf(style_mel, tf.float32)
inputs = numpy_to_tf(inputs, tf.int32)
inputs = tf.expand_dims(inputs, 0)
# synthesize voice
decoder_output, postnet_output, alignments, stop_tokens = run_model(
model, inputs, CONFIG, truncated, speaker_id, style_mel)
if backend == 'torch':
decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
model, inputs, CONFIG, truncated, speaker_id, style_mel)
postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
postnet_output, decoder_output, alignments, stop_tokens)
else:
decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
model, inputs, CONFIG, truncated, speaker_id, style_mel)
postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
postnet_output, decoder_output, alignments, stop_tokens)
# convert outputs to numpy
postnet_output, decoder_output, alignment = parse_outputs(
postnet_output, decoder_output, alignments)
# plot results
wav = None
if use_griffin_lim:
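
With the backend switch above, the same synthesis entry point serves both runtimes. A minimal sketch of the TF path follows; it assumes the leading positional parameters (model, text, CONFIG, use_cuda, ap, ...) keep their torch-path order and that model_tf, CONFIG and ap are prepared as in the benchmark notebook earlier in this commit.

# Sketch only: model_tf, CONFIG and ap are assumed to come from the notebook setup above.
outputs = synthesis(model_tf,
                    "The quick brown fox jumps over the lazy dog.",
                    CONFIG,
                    use_cuda=False,        # not used by the TF tensor path shown above
                    ap=ap,
                    use_griffin_lim=True,  # quick Griffin-Lim vocoding for a sanity check
                    do_trim_silence=False,
                    backend='tf')
# the return value is unchanged from the torch path (waveform plus intermediate outputs)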

View File

@ -61,7 +61,6 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
plt.yticks(range(len(text)), list(text))
plt.colorbar()
# plot stopnet predictions
stop_tokens = stop_tokens.squeeze().detach().to('cpu').numpy()
plt.subplot(num_plot, 1, 2)
plt.plot(range(len(stop_tokens)), list(stop_tokens))
# plot postnet spectrogram