mirror of https://github.com/coqui-ai/TTS.git
Added support for Tacotron2 GST + ability to condition style input with wav or tokens
parent fe081d4f7c
commit 84b7ab6ee6
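For orientation, the diff below wires a GST (global style token) layer into Tacotron2 and lets the style input be conditioned either on a reference wav or on a dictionary of style-token weights. A minimal sketch of the two forms, using the names that appear in the new config and synthesize script further down (the concrete paths and values are placeholders, not from this commit):

# In config.json, "gst_style_input" can take one of two forms:
#   "gst_style_input": "path/to/reference.wav"                 # condition on a reference wav
#   "gst_style_input": {"0": 0.15, "1": 0.15, "5": -0.15}      # or weight individual style tokens
# The new synthesize script then forwards whichever form is set:
#   synthesis(model, text, C, use_cuda, ap, speaker_id,
#             style_wav=C.gst['gst_style_input'], ...)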
@@ -27,7 +27,6 @@ class Tacotron2(TacotronAbstract):
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=512,
gst_num_heads=4,
@@ -42,33 +41,18 @@ class Tacotron2(TacotronAbstract):
ddc_r, gst)

# init layer dims
decoder_in_features = 512
encoder_in_features = 512

if speaker_embedding_dim is None:
# if speaker_embedding_dim is None we need to use nn.Embedding, with the default speaker_embedding_dim
self.embeddings_per_sample = False
speaker_embedding_dim = 512
else:
# if speaker_embedding_dim is not None we need to use the speaker embedding per sample
self.embeddings_per_sample = True

# speaker and gst embeddings are concatenated in the decoder input
if num_speakers > 1:
decoder_in_features = decoder_in_features + speaker_embedding_dim # add speaker embedding dim
if self.gst:
decoder_in_features = decoder_in_features + gst_embedding_dim # add gst embedding dim

# embedding layer
speaker_embedding_dim = 512 if num_speakers > 1 else 0
gst_embedding_dim = gst_embedding_dim if self.gst else 0
decoder_in_features = 512+speaker_embedding_dim+gst_embedding_dim
encoder_in_features = 512 if num_speakers > 1 else 512
proj_speaker_dim = 80 if num_speakers > 1 else 0
# base layers
self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)

# speaker embedding layer
if num_speakers > 1:
if not self.embeddings_per_sample:
self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)

# base model layers
self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(encoder_in_features)
self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
@@ -99,7 +83,7 @@ class Tacotron2(TacotronAbstract):
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
return mel_outputs, mel_outputs_postnet, alignments

def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
# compute mask for padding
# B x T_in_max (boolean)
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@@ -108,18 +92,20 @@ class Tacotron2(TacotronAbstract):
# B x T_in_max x D_en
encoder_outputs = self.encoder(embedded_inputs, text_lengths)

if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)

if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
else:
# B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
else:
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)

encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
@@ -147,18 +133,24 @@ class Tacotron2(TacotronAbstract):
return decoder_outputs, postnet_outputs, alignments, stop_tokens

@torch.no_grad()
def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
def inference(self, text, speaker_ids=None, style_mel=None):
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)

if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, style_mel)

if self.num_speakers > 1:
if not self.embeddings_per_sample:
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
else:
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
else:
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)

decoder_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
@@ -168,21 +160,27 @@ class Tacotron2(TacotronAbstract):
decoder_outputs, postnet_outputs, alignments)
return decoder_outputs, postnet_outputs, alignments, stop_tokens

def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
def inference_truncated(self, text, speaker_ids=None, style_mel=None):
"""
Preserve model states for continuous inference
"""
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference_truncated(embedded_inputs)

if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, style_mel)

if self.num_speakers > 1:
if not self.embeddings_per_sample:
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
else:
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
else:
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)

mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
encoder_outputs)
@@ -165,7 +165,6 @@ class TacotronAbstract(ABC, nn.Module):
self.speaker_embeddings).squeeze(1)

def compute_gst(self, inputs, style_input):
""" Compute global style token """
device = inputs.device
if isinstance(style_input, dict):
query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
@@ -176,11 +175,17 @@ class TacotronAbstract(ABC, nn.Module):
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
elif style_input is None:
query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
for k_token in range(self.gst_style_tokens):
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * 0
else:
gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
return inputs
gst_outputs = self.gst_layer(style_input)
embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
return inputs, embedded_gst

@staticmethod
def _add_speaker_embedding(outputs, speaker_embeddings):
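Only the tail of the token-dictionary branch of compute_gst is visible in the hunk above. Judging from the fully visible `style_input is None` branch, the dict case presumably follows the same pattern, except that each selected style token is amplified by the weight supplied in the dict. A sketch of how that branch likely reads in full (the loop body matches the visible lines; the loop header over `style_input.items()` is an assumption):

# assumed full form of the dict branch of compute_gst, mirroring the None branch above
if isinstance(style_input, dict):
    query = torch.zeros(1, 1, self.gst_embedding_dim // 2).to(device)
    _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
    gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
    for k_token, v_amplifier in style_input.items():
        # pick the named style token and weight its attention output by v_amplifier
        key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
        gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
        gst_outputs = gst_outputs + gst_outputs_att * v_amplifier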
@@ -210,10 +210,13 @@ def synthesis(model,
if backend == 'torch':
if speaker_id is not None:
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
<<<<<<< HEAD:mozilla_voice_tts/tts/utils/synthesis.py

if speaker_embedding is not None:
speaker_embedding = embedding_to_torch(speaker_embedding, cuda=use_cuda)

=======
>>>>>>> Added support for Tacotron2 GST + abbility to condition style input with wav or tokens:utils/synthesis.py
if not isinstance(style_mel, dict):
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
@@ -0,0 +1,182 @@
# pylint: disable=redefined-outer-name, unused-argument
import os
import time
import argparse
import torch
import json
import string

from TTS.utils.synthesis import synthesis
from TTS.utils.generic_utils import setup_model
from TTS.utils.io import load_config
from TTS.utils.text.symbols import make_symbols, symbols, phonemes
from TTS.utils.audio import AudioProcessor


def tts(model,
vocoder_model,
C,
VC,
text,
ap,
ap_vocoder,
use_cuda,
batched_vocoder,
speaker_id=None,
figures=False):
t_1 = time.time()
use_vocoder_model = vocoder_model is not None
waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
model, text, C, use_cuda, ap, speaker_id, style_wav=C.gst['gst_style_input'],
truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars,
use_griffin_lim=(not use_vocoder_model), do_trim_silence=True)

if C.model == "Tacotron" and use_vocoder_model:
postnet_output = ap.out_linear_to_mel(postnet_output.T).T
# correct if there is a scale difference b/w two models
if use_vocoder_model:
postnet_output = ap._denormalize(postnet_output)
postnet_output = ap_vocoder._normalize(postnet_output)
vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
waveform = vocoder_model.generate(
vocoder_input.cuda() if use_cuda else vocoder_input,
batched=batched_vocoder,
target=8000,
overlap=400)
print(" > Run-time: {}".format(time.time() - t_1))
return alignment, postnet_output, stop_tokens, waveform

if __name__ == "__main__":

global symbols, phonemes

parser = argparse.ArgumentParser()
parser.add_argument('text', type=str, help='Text to generate speech.')
parser.add_argument('config_path',
type=str,
help='Path to model config file.')
parser.add_argument(
'model_path',
type=str,
help='Path to model file.',
)
parser.add_argument(
'out_path',
type=str,
help='Path to save final wav file. Wav file will be named as the text given.',
)
parser.add_argument('--use_cuda',
type=bool,
help='Run model on CUDA.',
default=False)
parser.add_argument(
'--vocoder_path',
type=str,
help=
'Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).',
default="",
)
parser.add_argument('--vocoder_config_path',
type=str,
help='Path to vocoder model config file.',
default="")
parser.add_argument(
'--batched_vocoder',
type=bool,
help="If True, vocoder model uses faster batch processing.",
default=True)
parser.add_argument('--speakers_json',
type=str,
help="JSON file for multi-speaker model.",
default="")
parser.add_argument(
'--speaker_id',
type=int,
help="target speaker_id if the model is multi-speaker.",
default=None)
args = parser.parse_args()

if args.vocoder_path != "":
assert args.use_cuda, " [!] Enable cuda for vocoder."
from WaveRNN.models.wavernn import Model as VocoderModel

# load the config
C = load_config(args.config_path)
C.forward_attn_mask = True

# load the audio processor
ap = AudioProcessor(**C.audio)

# if the vocabulary was passed, replace the default
if 'characters' in C.keys():
symbols, phonemes = make_symbols(**C.characters)

# load speakers
if args.speakers_json != '':
speakers = json.load(open(args.speakers_json, 'r'))
num_speakers = len(speakers)
else:
num_speakers = 0

# load the model
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, C)
cp = torch.load(args.model_path)
model.load_state_dict(cp['model'])
model.eval()
if args.use_cuda:
model.cuda()
model.decoder.set_r(cp['r'])

# load vocoder model
if args.vocoder_path != "":
VC = load_config(args.vocoder_config_path)
ap_vocoder = AudioProcessor(**VC.audio)
bits = 10
vocoder_model = VocoderModel(rnn_dims=512,
fc_dims=512,
mode=VC.mode,
mulaw=VC.mulaw,
pad=VC.pad,
upsample_factors=VC.upsample_factors,
feat_dims=VC.audio["num_mels"],
compute_dims=128,
res_out_dims=128,
res_blocks=10,
hop_length=ap.hop_length,
sample_rate=ap.sample_rate,
use_aux_net=True,
use_upsample_net=True)

check = torch.load(args.vocoder_path)
vocoder_model.load_state_dict(check['model'])
vocoder_model.eval()
if args.use_cuda:
vocoder_model.cuda()
else:
vocoder_model = None
VC = None
ap_vocoder = None

# synthesize voice
print(" > Text: {}".format(args.text))
_, _, _, wav = tts(model,
vocoder_model,
C,
VC,
args.text,
ap,
ap_vocoder,
args.use_cuda,
args.batched_vocoder,
speaker_id=args.speaker_id,
figures=False)

# save the results
file_name = args.text.replace(" ", "_")
file_name = file_name.translate(
str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'
out_path = os.path.join(args.out_path, file_name)
print(" > Saving output to {}".format(out_path))
ap.save_wav(wav, out_path)
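For reference, assuming the new script above is saved as synthesize.py (the target file name is not shown in this view), it would be invoked roughly as `python synthesize.py "This is a test sentence." config.json checkpoint.pth.tar out/ --speaker_id 0`, with `--vocoder_path` and `--vocoder_config_path` pointing at a WaveRNN model when Griffin-Lim output is not wanted; the style conditioning itself comes from `gst_style_input` in the config, not from a CLI flag.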
@@ -1,3 +1,4 @@
<<<<<<< HEAD:tests/inputs/test_train_config.json
{
"model": "Tacotron2",
"run_name": "test_sample_dataset_run",

@@ -150,3 +151,161 @@

}
=======
{
"model": "Tacotron2",
"run_name": "ljspeech-ddc-bn",
"run_description": "tacotron2 with ddc and batch-normalization",

// AUDIO PARAMETERS
"audio":{
// stft parameters
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectrogram frame.
"win_length": 1024, // stft window length in ms.
"hop_length": 256, // stft window hop-length in ms.
"frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
"frame_shift_ms": null, // stft window hop-length in ms. If null, 'hop_length' is used.

// Audio processing parameters
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate.
"preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.

// Silence trimming
"do_trim_silence": true,// enable trimming of silence of audio as you load it. LJSpeech (true), TWEB (false), Nancy (true)
"trim_db": 60, // threshold for trimming silence. Set this according to your dataset.

// Griffin-Lim
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.

// MelSpectrogram parameters
"num_mels": 80, // size of the mel spec frame.
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"spec_gain": 20,

// Normalization parameters
"signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
"min_level_db": -100, // lower bound for normalization
"symmetric_norm": true, // move normalization to range [-1, 1]
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored
},

// VOCABULARY PARAMETERS
// if custom character set is not defined,
// default set in symbols.py is used
// "characters":{
// "pad": "_",
// "eos": "~",
// "bos": "^",
// "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
// "punctuations":"!'(),-.:;? ",
// "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
// },

// DISTRIBUTED TRAINING
"distributed":{
"backend": "nccl",
"url": "tcp:\/\/localhost:54321"
},

"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.

// TRAINING
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size":16,
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
"gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceed.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"ga_alpha": 10.0, // weight for guided attention loss. If > 0, guided attention is enabled.

// VALIDATION
"run_eval": true,
"test_delay_epochs": 10, //Until attention is aligned, testing only wastes computation time.
"test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.

// OPTIMIZER
"noam_schedule": false, // use noam warmup and lr schedule.
"grad_clip": 1.0, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"wd": 0.000001, // Weight decay weight.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"seq_len_norm": false, // Normalize each sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths.

// TACOTRON PRENET
"memory_size": -1, // ONLY TACOTRON - size of the memory queue used for storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame.
"prenet_type": "bn", // "original" or "bn".
"prenet_dropout": false, // enable/disable dropout at prenet.

// TACOTRON ATTENTION
"attention_type": "original", // 'original' or 'graves'
"attention_heads": 4, // number of attention heads (only for 'graves')
"attention_norm": "sigmoid", // softmax or sigmoid.
"windowing": false, // Enables attention windowing. Used only in eval mode.
"use_forward_attn": false, // if it uses forward attention. In general, it aligns faster.
"forward_attn_mask": false, // Additional masking forcing monotonicity only in eval mode.
"transition_agent": false, // enable/disable transition agent of forward attention.
"location_attn": true, // enable/disable location sensitive attention. It is enabled for TACOTRON by default.
"bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
"double_decoder_consistency": true, // use DDC explained here https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency-draft/
"ddc_r": 7, // reduction rate for coarse decoder.

// STOPNET
"stopnet": true, // Train stopnet predicting the end of synthesis.
"separate_stopnet": true, // Train stopnet separately if 'stopnet==true'. It prevents stopnet loss from influencing the rest of the model. It results in a better model, but it trains SLOWER.

// TENSORBOARD and LOGGING
"print_step": 25, // Number of steps to log training on console.
"tb_plot_step:": 100, // Number of steps to plot TB training figures.
"print_eval": false, // If True, it prints intermediate loss values in evaluation.
"save_step": 10000, // Number of training steps expected to save training stats and checkpoints.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

// DATA LOADING
"text_cleaner": "phoneme_cleaners",
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"num_val_loader_workers": 4, // number of evaluation data loader processes.
"batch_group_size": 0, //Number of batches to shuffle after bucketing.
"min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training
"max_seq_len": 153, // DATASET-RELATED: maximum text length

// PATHS
"output_path": "/home/erogol/Models/LJSpeech/",

// PHONEMES
"phoneme_cache_path": "/media/erogol/data_ssd2/mozilla_us_phonemes_3", // phoneme computation is slow, therefore, it caches results in the given folder.
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronunciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages

// MULTI-SPEAKER and GST
"use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning.
"use_gst": true, // use global style tokens
"gst": { // gst parameters if gst is enabled
"gst_style_input": null, // Condition the style input either on a
// -> wave file [path to wave] or
// -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
// with the dictionary being len(dict) == len(gst_style_tokens).
"gst_embedding_dim": 512,
"gst_num_heads": 4,
"gst_style_tokens": 10
},

// DATASETS
"datasets": // List of datasets. They are all merged and they get different speaker_ids.
[
{
"name": "ljspeech",
"path": "/home/erogol/Data/LJSpeech-1.1/",
"meta_file_train": "metadata.csv",
"meta_file_val": null
}
]
}

>>>>>>> Added support for Tacotron2 GST + abbility to condition style input with wav or tokens:config.json
@@ -0,0 +1,374 @@
import os
import glob
import torch
import shutil
import datetime
import subprocess
import importlib
import numpy as np
from collections import Counter


def get_git_branch():
try:
out = subprocess.check_output(["git", "branch"]).decode("utf8")
current = next(line for line in out.split("\n")
if line.startswith("*"))
current = current.replace("* ", "")
except subprocess.CalledProcessError:
current = "inside_docker"
return current


def get_commit_hash():
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
# try:
# subprocess.check_output(['git', 'diff-index', '--quiet',
# 'HEAD']) # Verify client is clean
# except:
# raise RuntimeError(
# " !! Commit before training to get the commit hash.")
try:
commit = subprocess.check_output(
['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
# Not copying .git folder into docker container
except subprocess.CalledProcessError:
commit = "0000000"
print(' > Git Hash: {}'.format(commit))
return commit


def create_experiment_folder(root_path, model_name, debug):
""" Create a folder with the current date and time """
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
if debug:
commit_hash = 'debug'
else:
commit_hash = get_commit_hash()
output_folder = os.path.join(
root_path, model_name + '-' + date_str + '-' + commit_hash)
os.makedirs(output_folder, exist_ok=True)
print(" > Experiment folder: {}".format(output_folder))
return output_folder


def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder"""

checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
if not checkpoint_files:
if os.path.exists(experiment_path):
shutil.rmtree(experiment_path, ignore_errors=True)
print(" ! Run is removed from {}".format(experiment_path))
else:
print(" ! Run is kept in {}".format(experiment_path))


def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)


def split_dataset(items):
is_multi_speaker = False
speakers = [item[-1] for item in items]
is_multi_speaker = len(set(speakers)) > 1
eval_split_size = 500 if len(items) * 0.01 > 500 else int(
len(items) * 0.01)
assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
np.random.seed(0)
np.random.shuffle(items)
if is_multi_speaker:
items_eval = []
# most stupid code ever -- Fix it !
while len(items_eval) < eval_split_size:
speakers = [item[-1] for item in items]
speaker_counter = Counter(speakers)
item_idx = np.random.randint(0, len(items))
if speaker_counter[items[item_idx][-1]] > 1:
items_eval.append(items[item_idx])
del items[item_idx]
return items_eval, items
return items[:eval_split_size], items[eval_split_size:]


# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
if max_len is None:
max_len = sequence_length.data.max()
batch_size = sequence_length.size(0)
seq_range = torch.arange(0, max_len).long()
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
if sequence_length.is_cuda:
seq_range_expand = seq_range_expand.to(sequence_length.device)
seq_length_expand = (
sequence_length.unsqueeze(1).expand_as(seq_range_expand))
# B x T_max
return seq_range_expand < seq_length_expand
def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
# 1. filter out unnecessary keys
pretrained_dict = {
k: v
for k, v in checkpoint_state.items() if k in model_dict
}
# 2. filter out different size layers
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if v.numel() == model_dict[k].numel()
}
# 3. skip reinit layers
if c.reinit_layers is not None:
for reinit_layer_name in c.reinit_layers:
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if reinit_layer_name not in k
}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
len(model_dict)))
return model_dict


def setup_model(num_chars, num_speakers, c):
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module('TTS.models.' + c.model.lower())
MyModel = getattr(MyModel, c.model)
if c.model.lower() in "tacotron":
model = MyModel(num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
decoder_output_dim=c.audio['num_mels'],
gst=c.use_gst,
gst_embedding_dim=c.gst['gst_embedding_dim'],
gst_num_heads=c.gst['gst_num_heads'],
gst_style_tokens=c.gst['gst_style_tokens'],
memory_size=c.memory_size,
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r)
elif c.model.lower() == "tacotron2":
model = MyModel(num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio['num_mels'],
decoder_output_dim=c.audio['num_mels'],
gst=c.use_gst,
gst_embedding_dim=c.gst['gst_embedding_dim'],
gst_num_heads=c.gst['gst_num_heads'],
gst_style_tokens=c.gst['gst_style_tokens'],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r)
return model

class KeepAverage():
def __init__(self):
self.avg_values = {}
self.iters = {}

def __getitem__(self, key):
return self.avg_values[key]

def items(self):
return self.avg_values.items()

def add_value(self, name, init_val=0, init_iter=0):
self.avg_values[name] = init_val
self.iters[name] = init_iter

def update_value(self, name, value, weighted_avg=False):
if name not in self.avg_values:
# add value if it does not exist yet
self.add_value(name, init_val=value)
else:
# else update existing value
if weighted_avg:
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
self.iters[name] += 1
else:
self.avg_values[name] = self.avg_values[name] * \
self.iters[name] + value
self.iters[name] += 1
self.avg_values[name] /= self.iters[name]

def add_values(self, name_dict):
for key, value in name_dict.items():
self.add_value(key, init_val=value)

def update_values(self, value_dict):
for key, value in value_dict.items():
self.update_value(key, value)


def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
if alternative in c.keys() and c[alternative] is not None:
return
if restricted:
assert name in c.keys(), f' [!] {name} not defined in config.json'
if name in c.keys():
if max_val:
assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}'
if min_val:
assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}'
if enum_list:
assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
if val_type:
assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'


def check_config(c):
_check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
_check_argument('run_name', c, restricted=True, val_type=str)
_check_argument('run_description', c, val_type=str)

# AUDIO
_check_argument('audio', c, restricted=True, val_type=dict)

# audio processing parameters
_check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
_check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
_check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
_check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
_check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
_check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
_check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
_check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
_check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
_check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)

# vocabulary parameters
_check_argument('characters', c, restricted=False, val_type=dict)
_check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
_check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
_check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
_check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
_check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
_check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)

# normalization parameters
_check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
_check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
_check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
_check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
_check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
_check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
_check_argument('spec_gain', c['audio'], restricted=True, val_type=float, min_val=1, max_val=100)
_check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
_check_argument('trim_db', c['audio'], restricted=True, val_type=int)

# training parameters
_check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
_check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
_check_argument('r', c, restricted=True, val_type=int, min_val=1)
_check_argument('gradual_training', c, restricted=False, val_type=list)
_check_argument('loss_masking', c, restricted=True, val_type=bool)
# _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)

# validation parameters
_check_argument('run_eval', c, restricted=True, val_type=bool)
_check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
_check_argument('test_sentences_file', c, restricted=False, val_type=str)

# optimizer
_check_argument('noam_schedule', c, restricted=False, val_type=bool)
_check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
_check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
_check_argument('lr', c, restricted=True, val_type=float, min_val=0)
_check_argument('wd', c, restricted=True, val_type=float, min_val=0)
_check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
_check_argument('seq_len_norm', c, restricted=True, val_type=bool)

# tacotron prenet
_check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
_check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
_check_argument('prenet_dropout', c, restricted=True, val_type=bool)

# attention
_check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
_check_argument('attention_heads', c, restricted=True, val_type=int)
_check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
_check_argument('windowing', c, restricted=True, val_type=bool)
_check_argument('use_forward_attn', c, restricted=True, val_type=bool)
_check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
_check_argument('transition_agent', c, restricted=True, val_type=bool)
_check_argument('location_attn', c, restricted=True, val_type=bool)
_check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
_check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
_check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)

# stopnet
_check_argument('stopnet', c, restricted=True, val_type=bool)
_check_argument('separate_stopnet', c, restricted=True, val_type=bool)

# tensorboard
_check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
_check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
_check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
_check_argument('checkpoint', c, restricted=True, val_type=bool)
_check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)

# dataloading
# pylint: disable=import-outside-toplevel
from TTS.utils.text import cleaners
_check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
_check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
_check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
_check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
_check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
_check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
_check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)

# paths
_check_argument('output_path', c, restricted=True, val_type=str)

# multi-speaker
_check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)

# GST
_check_argument('use_gst', c, restricted=True, val_type=bool)
_check_argument('gst_style_input', c, restricted=True, val_type=str)
_check_argument('gst', c, restricted=True, val_type=dict)
_check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1)
_check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1)
_check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)

# datasets - checking only the first entry
_check_argument('datasets', c, restricted=True, val_type=list)
for dataset_entry in c['datasets']:
_check_argument('name', dataset_entry, restricted=True, val_type=str)
_check_argument('path', dataset_entry, restricted=True, val_type=str)
_check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str)
_check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)