mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'dev'
commit e2d9b0e27c
.compute
@@ -2,4 +2,4 @@
 source ../tmp/venv/bin/activate
 # python extract_features.py --data_path ${DATA_ROOT}/shared/data/keithito/LJSpeech-1.1/ --cache_path ~/tts_cache/ --config config.json --num_proc 12 --dataset ljspeech --meta_file metadata.csv --val_split 1000 --process_audio true
 # python train.py --config_path config.json --data_path ~/tts_cache/ --debug true
-python train.py --config_path config.json --data_path ${DATA_ROOT}/shared/data/Blizzard/Nancy/ --debug true
+python train.py --config_path config.json --data_path ${DATA_ROOT}/shared/data/Blizzard/Nancy/ --debug true
README.md
@@ -75,7 +75,8 @@ Example datasets, we successfully applied TTS, are linked below.
 - [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
 - [Nancy](http://www.cstr.ed.ac.uk/projects/blizzard/2011/lessac_blizzard2011/)
-- [TWEB](http://https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset)\
+- [TWEB](http://https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset)
+- [M-AI-Labs](http://www.caito.de/2019/01/the-m-ailabs-speech-dataset/)
 
 ## Training and Fine-tuning LJ-Speech
 [Click Here](https://gist.github.com/erogol/97516ad65b44dbddb8cd694953187c5b) for hands-on **Notebook example**, training LJSpeech.
@@ -107,11 +108,7 @@ In case of any error or intercepted execution, if there is no checkpoint yet und
 You can also enjoy Tensorboard, if you point the Tensorboard argument```--logdir``` to the experiment folder.
 
 ## Testing
+Best way to test your pre-trained network is to use Notebooks under ```notebooks``` folder.
 
-## Logging
-# TODO
-TTS enables intense logging on Tensorboard.
-Best way to test your network is to use Notebooks under ```notebooks``` folder.
-
 ## What is new with TTS
 If you train TTS with LJSpeech dataset, you start to hear reasonable results after 12.5K iterations with batch size 32. This is the fastest training with character-based methods up to our knowledge. Out implementation is also quite robust against long sentences.
@@ -121,8 +118,11 @@ If you train TTS with LJSpeech dataset, you start to hear reasonable results aft
 - Weight decay ([ref](http://www.fast.ai/2018/07/02/adam-weight-decay/)). After a certain point of the training, you might observe the model over-fitting. That is, the model is able to pronounce words probably better but the quality of the speech quality gets lower and sometimes attention alignment gets disoriented.
 - Stop token prediction with an additional module. The original Tacotron model does not propose a stop token to stop the decoding process. Therefore, you need to use heuristic measures to stop the decoder. Here, we prefer to use additional layers at the end to decide when to stop.
 - Applying sigmoid to the model outputs. Since the output values are expected to be in the range [0, 1], we apply sigmoid to make things easier to approximate the expected output distribution.
+- Phoneme based training is enabled for easier learning and robust pronunciation. It also makes easier to adapt TTS to the most languages without worrying about language specific characters.
+- Configurable attention windowing at inference-time for robust alignment. It enforces network to only consider a certain window of encoder steps per iteration.
+- Detailed Tensorboard stats for activation, weight and gradient values per layer. It is useful to detect defects and compare networks.
 
-One common question is to ask why we don't use Tacotron2 architecture. According to our ablation experiments, nothing, except Location Sensitive Attention, improves the performance, given the big increase in the model size.
+One common question is to ask why we don't use Tacotron2 architecture. According to our ablation experiments, nothing, except Location Sensitive Attention, improves the performance, given the increase in the model size.
 
 Please feel free to offer new changes and pull things off. We are happy to discuss and make things better.
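The attention windowing added in this merge (see the README item above and the `AttentionRNNCell` hunks below) is easy to see on toy tensors. A minimal sketch, not the repo's API; the 3-back/6-front window sizes follow the diff's defaults:

```python
import torch

# Inference-time attention windowing: mask scores outside
# [win_idx - win_back, win_idx + win_front), then track the window.
def window_alignment(alignment, win_idx, win_back=3, win_front=6):
    back = win_idx - win_back
    front = win_idx + win_front
    if back > 0:
        alignment[:, :back] = -float("inf")
    if front < alignment.shape[1]:
        alignment[:, front:] = -float("inf")
    # the window follows the currently most attended encoder step
    new_win_idx = torch.argmax(alignment, 1)[0].item()
    return alignment, new_win_idx

scores = torch.randn(1, 20)                  # B x T_encoder attention scores
scores, win_idx = window_alignment(scores, win_idx=10)
weights = torch.softmax(scores, dim=-1)      # mass only inside the window
```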
config.json
@@ -1,6 +1,6 @@
 {
-    "model_name": "TTS-dev-tweb",
-    "model_description": "Higher dropout rate for stopnet and disabled custom initialization, pull current mel prediction to stopnet.",
+    "model_name": "queue",
+    "model_description": "Queue memory and change lower r incrementatlly",
 
     "audio":{
         "audio_processor": "audio", // to use dictate different audio processors, if available.
@@ -10,7 +10,7 @@
         "sample_rate": 22050, // wav sample-rate. If different than the original data, it is resampled.
         "frame_length_ms": 50, // stft window length in ms.
         "frame_shift_ms": 12.5, // stft window hop-lengh in ms.
-        "preemphasis": 0.97, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
+        "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
         "min_level_db": -100, // normalization range
         "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
         "power": 1.5, // value to sharpen wav signals after GL algorithm.
@@ -25,30 +25,36 @@
         "do_trim_silence": true // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
     },
 
-    "embedding_size": 256,
-    "text_cleaner": "english_cleaners",
-    "epochs": 1000,
-    "lr": 0.001,
-    "lr_decay": false,
-    "warmup_steps": 4000,
+    "embedding_size": 256, // Character embedding vector length. You don't need to change it in general.
+    "text_cleaner": "phoneme_cleaners",
+    "epochs": 1000, // total number of epochs to train.
+    "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
+    "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
+    "loss_weight": 0.0, // loss weight to emphasize lower frequencies. Lower frequencies are in general more important for speech signals.
+    "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
+    "windowing": false, // Enables attention windowing. Used only in eval mode.
+    "memory_size": 5, // memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
 
-    "batch_size": 20,
-    "eval_batch_size":32,
-    "r": 5,
-    "wd": 0.000001,
-    "checkpoint": true,
-    "save_step": 5000,
-    "print_step": 10,
-    "tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
+    "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
+    "eval_batch_size":32,
+    "r": 2, // Number of frames to predict for step.
+    "wd": 0.00001, // Weight decay weight.
+    "checkpoint": true, // If true, it saves checkpoints per "save_step"
+    "save_step": 5000, // Number of training steps expected to save traning stats and checkpoints.
+    "print_step": 50, // Number of steps to log traning on console.
+    "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
 
     "run_eval": true,
-    "data_path": "../../Data/LJSpeech-1.1/", // can overwritten from command argument
-    "meta_file_train": "transcript_train.txt", // metafile for training dataloader.
-    "meta_file_val": "transcript_val.txt", // metafile for evaluation dataloader.
-    "dataset": "tweb", // one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
-    "min_seq_len": 0, // minimum text length to use in training
-    "max_seq_len": 300, // maximum text length
-    "output_path": "/media/erogol/data_ssd/Data/models/tweb_models/", // output path for all training outputs.
+    "data_path": "/media/erogol/data_ssd/Data/LJSpeech-1.1", // DATASET-RELATED: can overwritten from command argument
+    "meta_file_train": "metadata_train.csv", // DATASET-RELATED: metafile for training dataloader.
+    "meta_file_val": "metadata_val.csv", // DATASET-RELATED: metafile for evaluation dataloader.
+    "dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
+    "min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training
+    "max_seq_len": 300, // DATASET-RELATED: maximum text length
+    "output_path": "/media/erogol/data_ssd/Data/models/ljspeech_models/", // DATASET-RELATED: output path for all training outputs.
     "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
-    "num_val_loader_workers": 4 // number of evaluation data loader processes.
+    "num_val_loader_workers": 4, // number of evaluation data loader processes.
+    "phoneme_cache_path": "ljspeech_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
+    "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
+    "phoneme_language": "en-us" // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
 }
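The `lr_decay` and `warmup_steps` comments refer to the Noam schedule (the `NoamLR` class imported in `train.py` below). A rough sketch of the standard formula — the repo's exact scaling may differ:

```python
# Noam schedule: LR ramps linearly for warmup_steps, then decays as
# 1/sqrt(step); base_lr is the peak, matching "lr" in config.json.
def noam_lr(step, base_lr=0.0001, warmup_steps=4000):
    step = max(step, 1)
    return base_lr * warmup_steps**0.5 * min(step**-0.5,
                                             step * warmup_steps**-1.5)

print(noam_lr(1), noam_lr(4000), noam_lr(100000))  # ramps up to base_lr, then decays
```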
@@ -6,7 +6,7 @@ import torch
 import random
 from torch.utils.data import Dataset
 
-from utils.text import text_to_sequence
+from utils.text import text_to_sequence, phoneme_to_sequence
 from utils.data import (prepare_data, pad_per_step, prepare_tensor,
                         prepare_stop_target)
 
@@ -22,7 +22,10 @@ class MyDataset(Dataset):
                  batch_group_size=0,
                  min_seq_len=0,
                  max_seq_len=float("inf"),
-                 cached=False):
+                 cached=False,
+                 use_phonemes=True,
+                 phoneme_cache_path=None,
+                 phoneme_language="en-us"):
         """
         Args:
             root_path (str): root path for the data folder.
@@ -40,6 +43,10 @@ class MyDataset(Dataset):
             max_seq_len (int): (float("inf")) maximum sequence length.
             cached (bool): (false) true if the given data path is created
                 by extract_features.py.
+            use_phonemes (bool): (true) if true, text converted to phonemes.
+            phoneme_cache_path (str): path to cache phoneme features.
+            phoneme_language (str): one the languages from
+                https://github.com/bootphon/phonemizer#languages
         """
         self.root_path = root_path
         self.batch_group_size = batch_group_size
@@ -51,8 +58,16 @@ class MyDataset(Dataset):
         self.max_seq_len = max_seq_len
         self.ap = ap
         self.cached = cached
+        self.use_phonemes = use_phonemes
+        self.phoneme_cache_path = phoneme_cache_path
+        self.phoneme_language = phoneme_language
+        if not os.path.isdir(phoneme_cache_path):
+            os.makedirs(phoneme_cache_path)
         print(" > DataLoader initialization")
         print(" | > Data path: {}".format(root_path))
+        print(" | > Use phonemes: {}".format(self.use_phonemes))
+        if use_phonemes:
+            print(" | > phoneme language: {}".format(phoneme_language))
         print(" | > Cached dataset: {}".format(self.cached))
         print(" | > Number of instances : {}".format(len(self.items)))
 
@@ -69,27 +84,42 @@ class MyDataset(Dataset):
         data = np.load(filename).astype('float32')
         return data
 
+    def load_phoneme_sequence(self, wav_file, text):
+        file_name = os.path.basename(wav_file).split('.')[0]
+        tmp_path = os.path.join(self.phoneme_cache_path, file_name+'_phoneme.npy')
+        if os.path.isfile(tmp_path):
+            text = np.load(tmp_path)
+        else:
+            text = np.asarray(
+                phoneme_to_sequence(text, [self.cleaners], language=self.phoneme_language), dtype=np.int32)
+            np.save(tmp_path, text)
+        return text
+
     def load_data(self, idx):
         if self.cached:
             wav_name = self.items[idx][1]
             mel_name = self.items[idx][2]
             linear_name = self.items[idx][3]
             text = self.items[idx][0]
             text = np.asarray(
                 text_to_sequence(text, [self.cleaners]), dtype=np.int32)
 
             if wav_name.split('.')[-1] == 'npy':
                 wav = self.load_np(wav_name)
             else:
                 wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
             mel = self.load_np(mel_name)
             linear = self.load_np(linear_name)
             sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1], 'mel':mel, 'linear': linear}
         else:
             text, wav_file = self.items[idx]
+            wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
+            mel = None
+            linear = None
+
+            if self.use_phonemes:
+                text = self.load_phoneme_sequence(wav_file, text)
+            else:
                 text = np.asarray(
                     text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-            wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
-            sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1]}
+            sample = {'text': text, 'wav': wav, 'item_idx': self.items[idx][1], 'mel':mel, 'linear': linear}
         return sample
 
     def sort_items(self):
@@ -148,12 +178,13 @@ class MyDataset(Dataset):
         text_lenghts = np.array([len(x) for x in text])
         max_text_len = np.max(text_lenghts)
 
-        if self.cached:
-            mel = [d['mel'] for d in batch]
-            linear = [d['linear'] for d in batch]
-        else:
-            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
-            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
+        # if specs are not computed, compute them.
+        if batch[0]['mel'] is None and batch[0]['linear'] is None:
+            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
+            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
+        else:
+            mel = [d['mel'] for d in batch]
+            linear = [d['linear'] for d in batch]
         mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
 
         # compute 'stop token' targets
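The phoneme cache added to the dataset above is just a numpy round-trip keyed by the wav file name. A standalone sketch, with a stand-in `to_phonemes` callable instead of the repo's `phoneme_to_sequence`:

```python
import os
import numpy as np

# Cache hit: load IDs from disk; cache miss: compute once and save.
def cached_phonemes(wav_file, text, cache_dir, to_phonemes):
    os.makedirs(cache_dir, exist_ok=True)
    file_name = os.path.basename(wav_file).split('.')[0]
    tmp_path = os.path.join(cache_dir, file_name + '_phoneme.npy')
    if os.path.isfile(tmp_path):
        return np.load(tmp_path)            # skip slow phonemization
    ids = np.asarray(to_phonemes(text), dtype=np.int32)
    np.save(tmp_path, ids)
    return ids

# Illustrative call; the fake to_phonemes just maps chars to ints.
ids = cached_phonemes('wavs/LJ001-0001.wav', 'hello', '/tmp/ph_cache',
                      lambda t: [ord(c) for c in t])
```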
datasets/TWEB.py
@@ -1,144 +0,0 @@
-import os
-import numpy as np
-import collections
-import librosa
-import torch
-from torch.utils.data import Dataset
-
-from TTS.utils.text import text_to_sequence
-from TTS.utils.audio import AudioProcessor
-from TTS.utils.data import (prepare_data, pad_per_step, prepare_tensor,
-                            prepare_stop_target)
-
-
-class TWEBDataset(Dataset):
-    def __init__(self,
-                 csv_file,
-                 root_dir,
-                 outputs_per_step,
-                 sample_rate,
-                 text_cleaner,
-                 num_mels,
-                 min_level_db,
-                 frame_shift_ms,
-                 frame_length_ms,
-                 preemphasis,
-                 ref_level_db,
-                 num_freq,
-                 power,
-                 min_seq_len=0):
-
-        with open(csv_file, "r") as f:
-            self.frames = [line.split('\t') for line in f]
-        self.root_dir = root_dir
-        self.outputs_per_step = outputs_per_step
-        self.sample_rate = sample_rate
-        self.cleaners = text_cleaner
-        self.min_seq_len = min_seq_len
-        self.ap = AudioProcessor(sample_rate, num_mels, min_level_db,
-                                 frame_shift_ms, frame_length_ms, preemphasis,
-                                 ref_level_db, num_freq, power)
-        print(" > Reading TWEB from - {}".format(root_dir))
-        print(" | > Number of instances : {}".format(len(self.frames)))
-        self._sort_frames()
-
-    def load_wav(self, filename):
-        try:
-            audio = librosa.core.load(filename, sr=self.sample_rate)
-            return audio
-        except RuntimeError as e:
-            print(" !! Cannot read file : {}".format(filename))
-
-    def _sort_frames(self):
-        r"""Sort sequences in ascending order"""
-        lengths = np.array([len(ins[1]) for ins in self.frames])
-
-        print(" | > Max length sequence {}".format(np.max(lengths)))
-        print(" | > Min length sequence {}".format(np.min(lengths)))
-        print(" | > Avg length sequence {}".format(np.mean(lengths)))
-
-        idxs = np.argsort(lengths)
-        new_frames = []
-        ignored = []
-        for i, idx in enumerate(idxs):
-            length = lengths[idx]
-            if length < self.min_seq_len:
-                ignored.append(idx)
-            else:
-                new_frames.append(self.frames[idx])
-        print(" | > {} instances are ignored by min_seq_len ({})".format(
-            len(ignored), self.min_seq_len))
-        self.frames = new_frames
-
-    def __len__(self):
-        return len(self.frames)
-
-    def __getitem__(self, idx):
-        wav_name = os.path.join(self.root_dir, self.frames[idx][0]) + '.wav'
-        text = self.frames[idx][1]
-        text = np.asarray(
-            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
-        wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
-        sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
-        return sample
-
-    def collate_fn(self, batch):
-        r"""
-            Perform preprocessing and create a final data batch:
-            1. PAD sequences with the longest sequence in the batch
-            2. Convert Audio signal to Spectrograms.
-            3. PAD sequences that can be divided by r.
-            4. Convert Numpy to Torch tensors.
-        """
-
-        # Puts each data field into a tensor with outer dimension batch size
-        if isinstance(batch[0], collections.Mapping):
-            keys = list()
-
-            wav = [d['wav'] for d in batch]
-            item_idxs = [d['item_idx'] for d in batch]
-            text = [d['text'] for d in batch]
-
-            text_lenghts = np.array([len(x) for x in text])
-            max_text_len = np.max(text_lenghts)
-
-            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
-            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
-            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
-
-            # compute 'stop token' targets
-            stop_targets = [
-                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
-            ]
-
-            # PAD stop targets
-            stop_targets = prepare_stop_target(stop_targets,
-                                               self.outputs_per_step)
-
-            # PAD sequences with largest length of the batch
-            text = prepare_data(text).astype(np.int32)
-            wav = prepare_data(wav)
-
-            # PAD features with largest length + a zero frame
-            linear = prepare_tensor(linear, self.outputs_per_step)
-            mel = prepare_tensor(mel, self.outputs_per_step)
-            assert mel.shape[2] == linear.shape[2]
-            timesteps = mel.shape[2]
-
-            # B x T x D
-            linear = linear.transpose(0, 2, 1)
-            mel = mel.transpose(0, 2, 1)
-
-            # convert things to pytorch
-            text_lenghts = torch.LongTensor(text_lenghts)
-            text = torch.LongTensor(text)
-            linear = torch.FloatTensor(linear)
-            mel = torch.FloatTensor(mel)
-            mel_lengths = torch.LongTensor(mel_lengths)
-            stop_targets = torch.FloatTensor(stop_targets)
-
-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
-                0]
-
-        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}".format(type(batch[0]))))
@@ -42,6 +42,28 @@ def tweb(root_path, meta_file):
     # return {'text': texts, 'wavs': wavs}
 
 
+def mailabs(root_path, meta_files):
+    """Normalizes M-AI-Labs meta data files to TTS format"""
+    folders = [os.path.dirname(f.strip()) for f in meta_files.split(",")]
+    meta_files = [f.strip() for f in meta_files.split(",")]
+    items = []
+    for idx, meta_file in enumerate(meta_files):
+        print(" | > {}".format(meta_file))
+        folder = folders[idx]
+        txt_file = os.path.join(root_path, meta_file)
+        with open(txt_file, 'r') as ttf:
+            for line in ttf:
+                cols = line.split('|')
+                wav_file = os.path.join(root_path, folder, 'wavs', cols[0]+'.wav')
+                if os.path.isfile(wav_file):
+                    text = cols[1]
+                    items.append([text, wav_file])
+                else:
+                    continue
+    random.shuffle(items)
+    return items
+
+
 def ljspeech(root_path, meta_file):
     """Normalizes the Nancy meta data file to TTS format"""
     txt_file = os.path.join(root_path, meta_file)
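For reference, the new `mailabs` preprocessor expects a comma-separated list of metafiles relative to `root_path`; a hypothetical call, with invented paths for illustration:

```python
# Paths are made up; M-AI-Labs organizes data per book under by_book/.
items = mailabs("/data/m-ailabs/en_US",
                "by_book/female/mary_ann/northandsouth/metadata.csv,"
                "by_book/male/elliot_miller/pirates_of_ersatz/metadata.csv")
text, wav_file = items[0]  # each item is a [text, wav_file] pair
```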
@@ -100,7 +100,7 @@ class LocationSensitiveAttention(nn.Module):
 
 
 class AttentionRNNCell(nn.Module):
-    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model):
+    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False):
         r"""
         General Attention RNN wrapper
 
@@ -110,10 +110,17 @@ class AttentionRNNCell(nn.Module):
             annot_dim (int): annotation vector feature dimension.
             memory_dim (int): memory vector (decoder output) feature dimension.
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
+            windowing (bool): attention windowing forcing monotonic attention.
+                It is only active in eval mode.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
         self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
+        self.windowing = windowing
+        if self.windowing:
+            self.win_back = 3
+            self.win_front = 6
+            self.win_idx = None
         # pick bahdanau or location sensitive attention
         if align_model == 'b':
             self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
@@ -138,6 +145,7 @@ class AttentionRNNCell(nn.Module):
         """
         if t == 0:
             self.alignment_model.reset()
+            self.win_idx = 0
         # Feed it to RNN
         # s_i = f(y_{i-1}, c_{i}, s_{i-1})
         rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
@@ -151,6 +159,16 @@ class AttentionRNNCell(nn.Module):
         if mask is not None:
             mask = mask.view(memory.size(0), -1)
             alignment.masked_fill_(1 - mask, -float("inf"))
+        # Windowing
+        if not self.training and self.windowing:
+            back_win = self.win_idx - self.win_back
+            front_win = self.win_idx + self.win_front
+            if back_win > 0:
+                alignment[:, :back_win] = -float("inf")
+            if front_win < memory.shape[1]:
+                alignment[:, front_win:] = -float("inf")
+            # Update the window
+            self.win_idx = torch.argmax(alignment,1).long()[0].item()
         # Normalize context weight
         # alignment = F.softmax(alignment, dim=-1)
         # alignment = 5 * alignment
@@ -22,14 +22,13 @@ class Prenet(nn.Module):
             for (in_size, out_size) in zip(in_features, out_features)
         ])
         self.relu = nn.ReLU()
-        self.dropout = nn.Dropout(0.1)
+        self.dropout = nn.Dropout(0.5)
         # self.init_layers()
 
     def init_layers(self):
         for layer in self.layers:
             torch.nn.init.xavier_uniform_(
-                layer.weight,
-                gain=torch.nn.init.calculate_gain('relu'))
+                layer.weight, gain=torch.nn.init.calculate_gain('relu'))
 
     def forward(self, inputs):
         for linear in self.layers:
@@ -88,8 +87,7 @@ class BatchNormConv1d(nn.Module):
         else:
             raise RuntimeError('Unknown activation function')
         torch.nn.init.xavier_uniform_(
-            self.conv1d.weight,
-            gain=torch.nn.init.calculate_gain(w_gain))
+            self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain))
 
     def forward(self, x):
         x = self.padder(x)
@@ -113,11 +111,9 @@ class Highway(nn.Module):
 
     def init_layers(self):
         torch.nn.init.xavier_uniform_(
-            self.H.weight,
-            gain=torch.nn.init.calculate_gain('relu'))
+            self.H.weight, gain=torch.nn.init.calculate_gain('relu'))
         torch.nn.init.xavier_uniform_(
-            self.T.weight,
-            gain=torch.nn.init.calculate_gain('sigmoid'))
+            self.T.weight, gain=torch.nn.init.calculate_gain('sigmoid'))
 
     def forward(self, inputs):
         H = self.relu(self.H(inputs))
@@ -302,23 +298,26 @@ class Decoder(nn.Module):
         in_features (int): input vector (encoder output) sample size.
         memory_dim (int): memory vector (prev. time-step output) sample size.
         r (int): number of outputs per time step.
+        memory_size (int): size of the past window. if <= 0 memory_size = r
     """
 
-    def __init__(self, in_features, memory_dim, r):
+    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing):
        super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
         self.max_decoder_steps = 500
+        self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
-        self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
+        self.prenet = Prenet(memory_dim * self.memory_size, out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNNCell(
             out_dim=128,
             rnn_dim=256,
             annot_dim=in_features,
             memory_dim=128,
-            align_model='ls')
+            align_model='ls',
+            windowing=attn_windowing)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -326,6 +325,10 @@ class Decoder(nn.Module):
             [nn.GRUCell(256, 256) for _ in range(2)])
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
+        # learn init values instead of zero init.
+        self.attention_rnn_init = nn.Embedding(1, 256)
+        self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
+        self.decoder_rnn_inits = nn.Embedding(2, 256)
         self.stopnet = StopNet(256 + memory_dim * r)
         # self.init_layers()
 
@@ -338,6 +341,9 @@ class Decoder(nn.Module):
             gain=torch.nn.init.calculate_gain('linear'))
 
     def _reshape_memory(self, memory):
+        """
+        Reshape the spectrograms for given 'r'
+        """
         B = memory.shape[0]
         # Grouping multiple frames if necessary
         if memory.size(-1) == self.memory_dim:
@@ -347,6 +353,28 @@ class Decoder(nn.Module):
         memory = memory.transpose(0, 1)
         return memory
 
+    def _init_states(self, inputs):
+        """
+        Initialization of decoder states
+        """
+        B = inputs.size(0)
+        T = inputs.size(1)
+        # go frame as zeros matrix
+        initial_memory = self.memory_init(inputs.data.new_zeros(B).long())
+
+        # decoder states
+        attention_rnn_hidden = self.attention_rnn_init(inputs.data.new_zeros(B).long())
+        decoder_rnn_hiddens = [
+            self.decoder_rnn_inits(inputs.data.new_tensor([idx]*B).long())
+            for idx in range(len(self.decoder_rnns))
+        ]
+        current_context_vec = inputs.data.new(B, self.in_features).zero_()
+        # attention states
+        attention = inputs.data.new(B, T).zero_()
+        attention_cum = inputs.data.new(B, T).zero_()
+        return (initial_memory, attention_rnn_hidden, decoder_rnn_hiddens,
+                current_context_vec, attention, attention_cum)
+
     def forward(self, inputs, memory=None, mask=None):
         """
         Decoder forward step.
@@ -365,36 +393,28 @@ class Decoder(nn.Module):
             - inputs: batch x time x encoder_out_dim
             - memory: batch x #mel_specs x mel_spec_dim
         """
-        B = inputs.size(0)
-        T = inputs.size(1)
         # Run greedy decoding if memory is None
         greedy = not self.training
         if memory is not None:
             memory = self._reshape_memory(memory)
             T_decoder = memory.size(0)
-        # go frame as zeros matrix
-        initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
-        # decoder states
-        attention_rnn_hidden = inputs.data.new(B, 256).zero_()
-        decoder_rnn_hiddens = [
-            inputs.data.new(B, 256).zero_()
-            for _ in range(len(self.decoder_rnns))
-        ]
-        current_context_vec = inputs.data.new(B, self.in_features).zero_()
-        # attention states
-        attention = inputs.data.new(B, T).zero_()
-        attention_cum = inputs.data.new(B, T).zero_()
         outputs = []
         attentions = []
         stop_tokens = []
         t = 0
-        memory_input = initial_memory
+        memory_input, attention_rnn_hidden, decoder_rnn_hiddens,\
+            current_context_vec, attention, attention_cum = self._init_states(inputs)
         while True:
             if t > 0:
                 if memory is None:
-                    memory_input = outputs[-1]
+                    new_memory = outputs[-1]
                 else:
-                    memory_input = memory[t - 1]
+                    new_memory = memory[t - 1]
+                # Queuing if memory size defined else use previous prediction only.
+                if self.memory_size > 0:
+                    memory_input = torch.cat([memory_input[:, self.r * self.memory_dim:].clone(), new_memory], dim=-1)
+                else:
+                    memory_input = new_memory
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
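The queuing logic above keeps the last `memory_size` frames as prenet input, dropping the oldest `r` frames each step. A shape-only sketch with the diff's defaults (`memory_dim=80`, `r=2`, `memory_size=5`); values are fake:

```python
import torch

B, memory_dim, r, memory_size = 1, 80, 2, 5
memory_input = torch.zeros(B, memory_size * memory_dim)  # learned init in the model
new_memory = torch.randn(B, r * memory_dim)              # latest decoder prediction

# Slide the queue: drop the oldest r frames, append the newest r frames.
memory_input = torch.cat(
    [memory_input[:, r * memory_dim:].clone(), new_memory], dim=-1)
assert memory_input.shape == (B, memory_size * memory_dim)
```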
@@ -412,7 +432,7 @@ class Decoder(nn.Module):
             for idx in range(len(self.decoder_rnns)):
                 decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
                     decoder_input, decoder_rnn_hiddens[idx])
-                # Residual connectinon
+                # Residual connection
                 decoder_input = decoder_rnn_hiddens[idx] + decoder_input
             decoder_output = decoder_input
             del decoder_input
@@ -433,7 +453,8 @@ class Decoder(nn.Module):
                 if t >= T_decoder:
                     break
             else:
-                if t > inputs.shape[1] / 4 and stop_token > 0.6:
+                if t > inputs.shape[1] / 4 and (stop_token > 0.6 or
+                                                attention[:, -1].item() > 0.6):
                     break
                 elif t > self.max_decoder_steps:
                     print(" | > Decoder stopped with 'max_decoder_steps")
@@ -459,8 +480,7 @@ class StopNet(nn.Module):
         self.linear = nn.Linear(in_features, 1)
         self.sigmoid = nn.Sigmoid()
         torch.nn.init.xavier_uniform_(
-            self.linear.weight,
-            gain=torch.nn.init.calculate_gain('linear'))
+            self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
 
     def forward(self, inputs):
         outputs = self.dropout(inputs)
@@ -1,27 +1,30 @@
 # coding: utf-8
 import torch
 from torch import nn
 from utils.text.symbols import symbols
 from math import sqrt
 from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
 
 
 class Tacotron(nn.Module):
     def __init__(self,
+                 num_chars,
                  embedding_dim=256,
                  linear_dim=1025,
                  mel_dim=80,
                  r=5,
-                 padding_idx=None):
+                 padding_idx=None,
+                 memory_size=5,
+                 attn_windowing=False):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(
-            len(symbols), embedding_dim, padding_idx=padding_idx)
-        print(" | > Number of characters : {}".format(len(symbols)))
+            num_chars, embedding_dim, padding_idx=padding_idx)
+        print(" | > Number of characters : {}".format(num_chars))
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(embedding_dim)
-        self.decoder = Decoder(256, mel_dim, r)
+        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
File diff suppressed because one or more lines are too long
@@ -11,3 +11,4 @@ flask
 scipy==0.19.0
 lws
 tqdm
+git+git://github.com/bootphon/phonemizer@master
setup.py
@@ -82,6 +82,10 @@ setup(
         "flask",
         # "lws",
         "tqdm",
+        "phonemizer",
+    ],
+    dependency_links=[
+        'http://github.com/bootphon/phonemizer/tarball/master#egg=phonemizer'
     ],
     extras_require={
         "bin": [
@@ -0,0 +1 @@
+print("Python is running!!")
train.py
@@ -17,6 +17,7 @@ from utils.generic_utils import (
     remove_experiment_folder, create_experiment_folder, save_checkpoint,
     save_best_model, load_config, lr_decay, count_parameters, check_update,
     get_commit_hash, sequence_mask, NoamLR)
+from utils.text.symbols import symbols, phonemes
 from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
@@ -46,7 +47,11 @@ def setup_loader(is_val=False):
         batch_group_size=0 if is_val else 8 * c.batch_size,
         min_seq_len=0 if is_val else c.min_seq_len,
         max_seq_len=float("inf") if is_val else c.max_seq_len,
-        cached=False if c.dataset != "tts_cache" else True)
+        cached=False if c.dataset != "tts_cache" else True,
+        phoneme_cache_path=c.phoneme_cache_path,
+        use_phonemes=c.use_phonemes,
+        phoneme_language=c.phoneme_language
+    )
     loader = DataLoader(
         dataset,
         batch_size=c.eval_batch_size if is_val else c.batch_size,
@@ -121,8 +126,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
         # loss computation
         stop_loss = criterion_st(stop_tokens, stop_targets)
         mel_loss = criterion(mel_output, mel_input, mel_lengths)
-        linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths)\
-            + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
+        linear_loss = (1 - c.loss_weight) * criterion(linear_output, linear_input, mel_lengths)\
+            + c.loss_weight * criterion(linear_output[:, :, :n_priority_freq],
                               linear_input[:, :, :n_priority_freq],
                               mel_lengths)
         loss = mel_loss + linear_loss
@@ -351,7 +356,8 @@ def evaluate(model, criterion, criterion_st, ap, current_step):
 
 
 def main(args):
-    model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r)
+    num_chars = len(phonemes) if c.use_phonemes else len(symbols)
+    model = Tacotron(num_chars, c.embedding_size, ap.num_freq, ap.num_mels, c.r, c.memory_size)
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
@@ -361,28 +367,39 @@ def main(args):
     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
 
+    partial_init_flag = False
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
         try:
             model.load_state_dict(checkpoint['model'])
         except:
             print(" > Partial model initialization.")
+            partial_init_flag = True
             model_dict = model.state_dict()
             # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
             # 1. filter out unnecessary keys
             pretrained_dict = {
                 k: v
                 for k, v in checkpoint['model'].items() if k in model_dict
             }
-            # 2. overwrite entries in the existing state dict
+            # 2. filter out different size layers
+            pretrained_dict = {
+                k: v
+                for k, v in checkpoint['model'].items() if v.numel() == model_dict[k].numel()
+            }
+            # 3. overwrite entries in the existing state dict
             model_dict.update(pretrained_dict)
-            # 3. load the new state dict
+            # 4. load the new state dict
             model.load_state_dict(model_dict)
+            print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
         if use_cuda:
             model = model.cuda()
             criterion.cuda()
             criterion_st.cuda()
-        optimizer.load_state_dict(checkpoint['optimizer'])
+        if not partial_init_flag:
+            optimizer.load_state_dict(checkpoint['optimizer'])
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(
             " > Model restored from step %d" % checkpoint['step'], flush=True)
         start_epoch = checkpoint['epoch']
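The partial restore above condenses into one helper. A sketch, not the repo's code — note it keeps the `k in model_dict` check inside the size filter, which the diff's step 2 drops (as written there, a checkpoint layer missing from the new model can raise a `KeyError`):

```python
import torch

# Copy only checkpoint tensors whose names and sizes match the current model.
# `model` and `checkpoint_path` are placeholders.
def partial_restore(model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model_dict = model.state_dict()
    pretrained = {
        k: v for k, v in checkpoint["model"].items()
        if k in model_dict and v.numel() == model_dict[k].numel()
    }
    model_dict.update(pretrained)
    model.load_state_dict(model_dict)
    print(" | > {} / {} layers are initialized".format(
        len(pretrained), len(model_dict)))
```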
@@ -423,7 +440,10 @@ def main(args):
             " | > Train Loss: {:.5f}   Validation Loss: {:.5f}".format(
                 train_loss, val_loss),
             flush=True)
-        best_loss = save_best_model(model, optimizer, train_loss, best_loss,
+        target_loss = train_loss
+        if c.run_eval:
+            target_loss = val_loss
+        best_loss = save_best_model(model, optimizer, target_loss, best_loss,
                                     OUT_PATH, current_step, epoch)
 
@@ -3,7 +3,7 @@ import time
 import librosa
 import torch
 import numpy as np
-from .text import text_to_sequence
+from .text import text_to_sequence, phoneme_to_sequence
 from .visual import visualize
 from matplotlib import pylab as plt
 
@@ -11,11 +11,19 @@ from matplotlib import pylab as plt
 def synthesis(m, s, CONFIG, use_cuda, ap):
     """ Given the text, synthesising the audio """
     text_cleaner = [CONFIG.text_cleaner]
-    seq = np.array(text_to_sequence(s, text_cleaner))
-    # print(phoneme_to_sequence(s, text_cleaner))s
-    # print(sequence_to_phoneme(phoneme_to_sequence(s, text_cleaner)))
+    if CONFIG.use_phonemes:
+        seq = np.asarray(
+            phoneme_to_sequence(s, text_cleaner, CONFIG.phoneme_language),
+            dtype=np.int32)
+    else:
+        seq = np.asarray(text_to_sequence(s, text_cleaner), dtype=np.int32)
     chars_var = torch.from_numpy(seq).unsqueeze(0)
     if use_cuda:
         chars_var = chars_var.cuda()
-    mel_spec, linear_spec, alignments, stop_tokens = m.forward(chars_var.long())
+    mel_spec, linear_spec, alignments, stop_tokens = m.forward(
+        chars_var.long())
     linear_spec = linear_spec[0].data.cpu().numpy()
     mel_spec = mel_spec[0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
@@ -1,16 +1,71 @@
 # -*- coding: utf-8 -*-
 
 import re
+import phonemizer
+from phonemizer.phonemize import phonemize
 from utils.text import cleaners
-from utils.text.symbols import symbols
+from utils.text.symbols import symbols, phonemes, _phoneme_punctuations
 
 # Mappings from symbol to numeric ID and vice versa:
 _symbol_to_id = {s: i for i, s in enumerate(symbols)}
 _id_to_symbol = {i: s for i, s in enumerate(symbols)}
 
+_phonemes_to_id = {s: i for i, s in enumerate(phonemes)}
+_id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
+
 # Regular expression matching text enclosed in curly braces:
 _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
 
+# Regular expression matchinf punctuations, ignoring empty space
+pat = r'['+_phoneme_punctuations[:-1]+']+'
+
+
+def text2phone(text, language):
+    '''
+    Convert graphemes to phonemes.
+    '''
+    seperator = phonemizer.separator.Separator(' |', '', '|')
+    #try:
+    punctuations = re.findall(pat, text)
+    ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language)
+    # Replace \n with matching punctuations.
+    if len(punctuations) > 0:
+        for punct in punctuations[:-1]:
+            ph = ph.replace('| |\n', '|'+punct+'| |', 1)
+        try:
+            ph = ph[:-1] + punctuations[-1]
+        except:
+            print(text)
+    return ph
+
+
+def phoneme_to_sequence(text, cleaner_names, language):
+    '''
+    TODO: This ignores punctuations
+    '''
+    sequence = []
+    clean_text = _clean_text(text, cleaner_names)
+    phonemes = text2phone(clean_text, language)
+    # print(phonemes.replace('|', ''))
+    if phonemes is None:
+        print("!! After phoneme conversion the result is None. -- {} ".format(clean_text))
+    for phoneme in phonemes.split('|'):
+        # print(word, ' -- ', phonemes_text)
+        sequence += _phoneme_to_sequence(phoneme)
+    # Aeepnd EOS char
+    sequence.append(_phonemes_to_id['~'])
+    return sequence
+
+
+def sequence_to_phoneme(sequence):
+    '''Converts a sequence of IDs back to a string'''
+    result = ''
+    for symbol_id in sequence:
+        if symbol_id in _id_to_phonemes:
+            s = _id_to_phonemes[symbol_id]
+            result += s
+    return result.replace('}{', ' ')
+
+
 def text_to_sequence(text, cleaner_names):
     '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
@@ -69,9 +124,17 @@ def _symbols_to_sequence(symbols):
     return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
 
 
+def _phoneme_to_sequence(phonemes):
+    return [_phonemes_to_id[s] for s in list(phonemes) if _should_keep_phoneme(s)]
+
+
 def _arpabet_to_sequence(text):
     return _symbols_to_sequence(['@' + s for s in text.split()])
 
 
 def _should_keep_symbol(s):
     return s in _symbol_to_id and s is not '_' and s is not '~'
+
+
+def _should_keep_phoneme(p):
+    return p in _phonemes_to_id and p is not '_' and p is not '~'
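The `text2phone` wrapper above drives phonemizer's espeak backend with a custom separator. A minimal call, assuming espeak is installed and the phonemizer version pinned in requirements above (newer phonemizer releases changed this API):

```python
import phonemizer
from phonemizer.phonemize import phonemize

# '|' marks phoneme boundaries, ' |' marks word boundaries, per the diff.
seperator = phonemizer.separator.Separator(' |', '', '|')
ph = phonemize("Hello world", separator=seperator, strip=False,
               njobs=1, backend='espeak', language='en-us')
print(ph)  # phoneme string split by '|'; exact symbols depend on the espeak version
```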
@@ -12,7 +12,7 @@ hyperparameter. Some cleaners are English-specific. You'll typically want to use
 
 import re
 from unidecode import unidecode
-from .numbers import normalize_numbers
+from .number_norm import normalize_numbers
 
 # Regular expression matching whitespace:
 _whitespace_re = re.compile(r'\s+')
@@ -86,3 +86,12 @@ def english_cleaners(text):
     text = expand_abbreviations(text)
     text = collapse_whitespace(text)
     return text
+
+
+def phoneme_cleaners(text):
+    '''Pipeline for phonemes mode, including number and abbreviation expansion.'''
+    text = convert_to_ascii(text)
+    text = expand_numbers(text)
+    text = expand_abbreviations(text)
+    text = collapse_whitespace(text)
+    return text
@@ -2,18 +2,16 @@
 
 import re
 
-valid_symbols = [
-    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
-    'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
-    'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
-    'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
-    'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
-    'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
-    'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
-    'Y', 'Z', 'ZH'
-]
-
-_valid_symbol_set = set(valid_symbols)
+# valid_symbols = [
+#     'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
+#     'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
+#     'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
+#     'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
+#     'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
+#     'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
+#     'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
+#     'Y', 'Z', 'ZH'
+# ]
 
 
 class CMUDict:
@@ -39,6 +37,20 @@ class CMUDict:
         '''Returns list of ARPAbet pronunciations of the given word.'''
         return self._entries.get(word.upper())
 
+    def get_arpabet(self, word, cmudict, punctuation_symbols):
+        first_symbol, last_symbol = '', ''
+        if len(word) > 0 and word[0] in punctuation_symbols:
+            first_symbol = word[0]
+            word = word[1:]
+        if len(word) > 0 and word[-1] in punctuation_symbols:
+            last_symbol = word[-1]
+            word = word[:-1]
+        arpabet = cmudict.lookup(word)
+        if arpabet is not None:
+            return first_symbol + '{%s}' % arpabet[0] + last_symbol
+        else:
+            return first_symbol + word + last_symbol
+
 
 _alt_re = re.compile(r'\([0-9]+\)')
 
@@ -10,12 +10,25 @@ from utils.text import cmudict
 _pad = '_'
 _eos = '~'
 _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
+_punctuations = '!\'(),-.:;? '
+_phoneme_punctuations = '.!;:,?'
+
+# TODO: include more phoneme characters for other languages.
+_phonemes = ['l','ɹ','ɜ','ɚ','k','u','ʔ','ð','ɐ','ɾ','ɑ','ɔ','b','ɛ','t','v','n','m','ʊ','ŋ','s',
+             'ʌ','o','ʃ','i','p','æ','e','a','ʒ',' ','h','ɪ','ɡ','f','r','w','ɫ','ɬ','d','x','ː',
+             'ᵻ','ə','j','θ','z','ɒ']
+
+_phonemes = sorted(list(set(_phonemes)))
 
 # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in cmudict.valid_symbols]
+_arpabet = ['@' + s for s in _phonemes]
 
 # Export all symbols:
 symbols = [_pad, _eos] + list(_characters) + _arpabet
+phonemes = [_pad, _eos] + list(_phonemes) + list(_punctuations)
 
 if __name__ == '__main__':
     print(" > TTS symbols ")
     print(symbols)
+    print(" > TTS phonemes ")
+    print(phonemes)