diff --git a/mozilla_voice_tts/bin/synthesize.py b/mozilla_voice_tts/bin/synthesize.py
index 6f139433..7d68aef3 100644
--- a/mozilla_voice_tts/bin/synthesize.py
+++ b/mozilla_voice_tts/bin/synthesize.py
@@ -18,9 +18,9 @@ from mozilla_voice_tts.utils.io import load_config
 from mozilla_voice_tts.vocoder.utils.generic_utils import setup_generator


-def tts(model, vocoder_model, text, CONFIG, use_cuda, ap, use_gl, speaker_id):
+def tts(model, vocoder_model, text, CONFIG, use_cuda, ap, use_gl, speaker_fileid, speaker_embedding=None, gst_style=None):
     t_1 = time.time()
-    waveform, _, _, mel_postnet_spec, _, _ = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, use_gl)
+    waveform, _, _, mel_postnet_spec, _, _ = synthesis(model, text, CONFIG, use_cuda, ap, speaker_fileid, gst_style, False, CONFIG.enable_eos_bos_chars, use_gl, speaker_embedding=speaker_embedding)
     if CONFIG.model == "Tacotron" and not use_gl:
         mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T
     if not use_gl:
@@ -80,10 +80,15 @@ if __name__ == "__main__":
         help="JSON file for multi-speaker model.",
         default="")
     parser.add_argument(
-        '--speaker_id',
-        type=int,
-        help="target speaker_id if the model is multi-speaker.",
+        '--speaker_fileid',
+        type=str,
+        help="if CONFIG.use_external_speaker_embedding_file is true, name of the speaker embedding reference file in speakers.json; otherwise, the target speaker_fileid if the model is multi-speaker.",
         default=None)
+    parser.add_argument(
+        '--gst_style',
+        help="Path to a wav file used as the GST style reference.",
+        default=None)
+
     args = parser.parse_args()

     # load the config
@@ -97,16 +102,24 @@ if __name__ == "__main__":
     if 'characters' in C.keys():
         symbols, phonemes = make_symbols(**C.characters)

+    speaker_embedding = None
+    speaker_embedding_dim = None
+    num_speakers = 0
+
     # load speakers
     if args.speakers_json != '':
-        speakers = json.load(open(args.speakers_json, 'r'))
-        num_speakers = len(speakers)
-    else:
-        num_speakers = 0
+        speaker_mapping = json.load(open(args.speakers_json, 'r'))
+        num_speakers = len(speaker_mapping)
+        if C.use_external_speaker_embedding_file:
+            if args.speaker_fileid is not None:
+                speaker_embedding = speaker_mapping[args.speaker_fileid]['embedding']
+            else:  # if speaker_fileid is not specified, use the first sample in speakers.json
+                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']
+            speaker_embedding_dim = len(speaker_embedding)

     # load the model
     num_chars = len(phonemes) if C.use_phonemes else len(symbols)
-    model = setup_model(num_chars, num_speakers, C)
+    model = setup_model(num_chars, num_speakers, C, speaker_embedding_dim)
     cp = torch.load(args.model_path, map_location=torch.device('cpu'))
     model.load_state_dict(cp['model'])
     model.eval()
@@ -130,7 +143,27 @@ if __name__ == "__main__":
     # synthesize voice
     use_griffin_lim = args.vocoder_path == ""
     print(" > Text: {}".format(args.text))
-    wav = tts(model, vocoder_model, args.text, C, args.use_cuda, ap, use_griffin_lim, args.speaker_id)
+
+    if not C.use_external_speaker_embedding_file:
+        if args.speaker_fileid.isdigit():
+            args.speaker_fileid = int(args.speaker_fileid)
+        else:
+            args.speaker_fileid = None
+    else:
+        args.speaker_fileid = None
+
+    if args.gst_style is None:
+        gst_style = C.gst['gst_style_input']
+    else:
+        # check if gst_style is a JSON dict; if so, parse it, otherwise treat it as a wav path
+        try:
+            gst_style = json.loads(args.gst_style)
+            if max(map(int, gst_style.keys())) >= C.gst['gst_style_tokens']:
+                raise RuntimeError("The highest gst_style dictionary key must be less than the number of GST tokens.\n Highest dictionary key value: {}\n Number of GST tokens: {}".format(max(map(int, gst_style.keys())), C.gst['gst_style_tokens']))
+        except ValueError:
+            gst_style = args.gst_style
+
+    wav = tts(model, vocoder_model, args.text, C, args.use_cuda, ap, use_griffin_lim, args.speaker_fileid, speaker_embedding=speaker_embedding, gst_style=gst_style)

     # save the results
     file_name = args.text.replace(" ", "_")
diff --git a/mozilla_voice_tts/bin/train_encoder.py b/mozilla_voice_tts/bin/train_encoder.py
index d612ac6e..f9bfea7f 100644
--- a/mozilla_voice_tts/bin/train_encoder.py
+++ b/mozilla_voice_tts/bin/train_encoder.py
@@ -10,21 +10,21 @@ import traceback
 import torch
 from torch.utils.data import DataLoader

-from mozilla_voice_tts.generic_utils import count_parameters
 from mozilla_voice_tts.speaker_encoder.dataset import MyDataset
 from mozilla_voice_tts.speaker_encoder.generic_utils import save_best_model
-from mozilla_voice_tts.speaker_encoder.loss import GE2ELoss
+from mozilla_voice_tts.speaker_encoder.losses import GE2ELoss, AngleProtoLoss
 from mozilla_voice_tts.speaker_encoder.model import SpeakerEncoder
 from mozilla_voice_tts.speaker_encoder.visual import plot_embeddings
 from mozilla_voice_tts.tts.datasets.preprocess import load_meta_data
-from mozilla_voice_tts.tts.utils.audio import AudioProcessor
 from mozilla_voice_tts.tts.utils.generic_utils import (
     create_experiment_folder, get_git_branch, remove_experiment_folder,
     set_init_dict)
 from mozilla_voice_tts.tts.utils.io import copy_config_file, load_config
-from mozilla_voice_tts.tts.utils.radam import RAdam
-from mozilla_voice_tts.tts.utils.tensorboard_logger import TensorboardLogger
-from mozilla_voice_tts.tts.utils.training import NoamLR, check_update
+from mozilla_voice_tts.utils.audio import AudioProcessor
+from mozilla_voice_tts.utils.generic_utils import count_parameters
+from mozilla_voice_tts.utils.radam import RAdam
+from mozilla_voice_tts.utils.tensorboard_logger import TensorboardLogger
+from mozilla_voice_tts.utils.training import NoamLR, check_update

 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True
@@ -100,7 +100,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
         if global_step % c.steps_plot_stats == 0:
             # Plot Training Epoch Stats
             train_stats = {
-                "GE2Eloss": avg_loss,
+                "loss": avg_loss,
                 "lr": current_lr,
                 "grad_norm": grad_norm,
                 "step_time": step_time
             }
@@ -135,12 +135,18 @@ def main(args): # pylint: disable=redefined-outer-name
     global meta_data_eval

     ap = AudioProcessor(**c.audio)
-    model = SpeakerEncoder(input_dim=40,
-                           proj_dim=128,
-                           lstm_dim=384,
-                           num_lstm_layers=3)
+    model = SpeakerEncoder(input_dim=c.model['input_dim'],
+                           proj_dim=c.model['proj_dim'],
+                           lstm_dim=c.model['lstm_dim'],
+                           num_lstm_layers=c.model['num_lstm_layers'])
     optimizer = RAdam(model.parameters(), lr=c.lr)
-    criterion = GE2ELoss(loss_method='softmax')
+
+    if c.loss == "ge2e":
+        criterion = GE2ELoss(loss_method='softmax')
+    elif c.loss == "angleproto":
+        criterion = AngleProtoLoss()
+    else:
+        raise Exception("%s is not a supported loss" % c.loss)

     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
@@ -242,7 +248,7 @@ if __name__ == '__main__':
                          new_fields)

     LOG_DIR = OUT_PATH
-    tb_logger = TensorboardLogger(LOG_DIR)
+    tb_logger = TensorboardLogger(LOG_DIR, model_name='Speaker_Encoder')

     try:
         main(args)
diff --git a/mozilla_voice_tts/bin/train_tts.py b/mozilla_voice_tts/bin/train_tts.py
index 719b926f..2b6cbfd0
100644 --- a/mozilla_voice_tts/bin/train_tts.py +++ b/mozilla_voice_tts/bin/train_tts.py @@ -49,7 +49,7 @@ from mozilla_voice_tts.utils.training import (NoamLR, adam_weight_decay, use_cuda, num_gpus = setup_torch_training_env(True, False) -def setup_loader(ap, r, is_val=False, verbose=False): +def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None): if is_val and not c.run_eval: loader = None else: @@ -68,7 +68,8 @@ def setup_loader(ap, r, is_val=False, verbose=False): use_phonemes=c.use_phonemes, phoneme_language=c.phoneme_language, enable_eos_bos=c.enable_eos_bos_chars, - verbose=verbose) + verbose=verbose, + speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) sampler = DistributedSampler(dataset) if num_gpus > 1 else None loader = DataLoader( dataset, @@ -82,9 +83,8 @@ def setup_loader(ap, r, is_val=False, verbose=False): pin_memory=False) return loader - -def format_data(data): - if c.use_speaker_embedding: +def format_data(data, speaker_mapping=None): + if speaker_mapping is None and c.use_speaker_embedding and not c.use_external_speaker_embedding_file: speaker_mapping = load_speaker_mapping(OUT_PATH) # setup input data @@ -99,13 +99,20 @@ def format_data(data): avg_spec_length = torch.mean(mel_lengths.float()) if c.use_speaker_embedding: - speaker_ids = [ - speaker_mapping[speaker_name] for speaker_name in speaker_names - ] - speaker_ids = torch.LongTensor(speaker_ids) + if c.use_external_speaker_embedding_file: + speaker_embeddings = data[8] + speaker_ids = None + else: + speaker_ids = [ + speaker_mapping[speaker_name] for speaker_name in speaker_names + ] + speaker_ids = torch.LongTensor(speaker_ids) + speaker_embeddings = None else: + speaker_embeddings = None speaker_ids = None + # set stop targets view, we predict a single stop token per iteration. 
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) @@ -122,13 +129,16 @@ def format_data(data): stop_targets = stop_targets.cuda(non_blocking=True) if speaker_ids is not None: speaker_ids = speaker_ids.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length + if speaker_embeddings is not None: + speaker_embeddings = speaker_embeddings.cuda(non_blocking=True) + + return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length def train(model, criterion, optimizer, optimizer_st, scheduler, - ap, global_step, epoch, amp): + ap, global_step, epoch, amp, speaker_mapping=None): data_loader = setup_loader(ap, model.decoder.r, is_val=False, - verbose=(epoch == 0)) + verbose=(epoch == 0), speaker_mapping=speaker_mapping) model.train() epoch_time = 0 keep_avg = KeepAverage() @@ -143,7 +153,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(data) + text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length = format_data(data, speaker_mapping) loader_time = time.time() - end_time global_step += 1 @@ -158,10 +168,10 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, # forward pass model if c.bidirectional_decoder or c.double_decoder_consistency: decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( - text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids) + text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) else: decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids) + text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) decoder_backward_output = None alignments_backward = None @@ -312,8 +322,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, @torch.no_grad() -def evaluate(model, criterion, ap, global_step, epoch): - data_loader = setup_loader(ap, model.decoder.r, is_val=True) +def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None): + data_loader = setup_loader(ap, model.decoder.r, is_val=True, speaker_mapping=speaker_mapping) model.eval() epoch_time = 0 keep_avg = KeepAverage() @@ -323,16 +333,16 @@ def evaluate(model, criterion, ap, global_step, epoch): start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data) + text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data, speaker_mapping) assert mel_input.shape[1] % model.decoder.r == 0 # forward pass model if c.bidirectional_decoder or c.double_decoder_consistency: decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids) + text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) else: 
decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids) + text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) decoder_backward_output = None alignments_backward = None @@ -494,22 +504,41 @@ def main(args): # pylint: disable=redefined-outer-name if c.use_speaker_embedding: speakers = get_speakers(meta_data_train) if args.restore_path: - prev_out_path = os.path.dirname(args.restore_path) - speaker_mapping = load_speaker_mapping(prev_out_path) - assert all([speaker in speaker_mapping - for speaker in speakers]), "As of now you, you cannot " \ - "introduce new speakers to " \ - "a previously trained model." - else: + if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file + prev_out_path = os.path.dirname(args.restore_path) + speaker_mapping = load_speaker_mapping(prev_out_path) + if not speaker_mapping: + print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file") + speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file) + if not speaker_mapping: + raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file") + speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']) + elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file + prev_out_path = os.path.dirname(args.restore_path) + speaker_mapping = load_speaker_mapping(prev_out_path) + speaker_embedding_dim = None + assert all([speaker in speaker_mapping + for speaker in speakers]), "As of now you, you cannot " \ + "introduce new speakers to " \ + "a previously trained model." 
+        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:  # if starting a new training run with an external embedding file
+            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
+            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
+        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file:  # if starting a new training run with external embeddings but no file was given
+            raise RuntimeError("use_external_speaker_embedding_file is True, so you need to pass an external speaker embedding file; run the GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in the notebooks/ folder")
+        else:  # if starting a new training run without an external embedding file
             speaker_mapping = {name: i for i, name in enumerate(speakers)}
+            speaker_embedding_dim = None
         save_speaker_mapping(OUT_PATH, speaker_mapping)
         num_speakers = len(speaker_mapping)
         print("Training with {} speakers: {}".format(num_speakers, ", ".join(speakers)))
     else:
         num_speakers = 0
+        speaker_embedding_dim = None
+        speaker_mapping = None

-    model = setup_model(num_chars, num_speakers, c)
+    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)

     params = set_weight_decay(model, c.wd)
     optimizer = RAdam(params, lr=c.lr, weight_decay=0)
@@ -544,6 +573,8 @@ def main(args): # pylint: disable=redefined-outer-name
             print(" > Partial model initialization.")
             model_dict = model.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
+            # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
+            # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
             model.load_state_dict(model_dict)
             del model_dict

@@ -592,7 +623,7 @@ def main(args): # pylint: disable=redefined-outer-name
         print("\n > Number of output frames:", model.decoder.r)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
-                                                 global_step, epoch, amp)
+                                                 global_step, epoch, amp, speaker_mapping)
         eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
diff --git a/mozilla_voice_tts/speaker_encoder/config.json b/mozilla_voice_tts/speaker_encoder/config.json
index 0d0f8f68..11da0cf6 100644
--- a/mozilla_voice_tts/speaker_encoder/config.json
+++ b/mozilla_voice_tts/speaker_encoder/config.json
@@ -1,26 +1,33 @@
+
 {
-    "run_name": "libritts_360-half",
-    "run_description": "train speaker encoder for libritts 360",
-    "audio": {
+    "run_name": "Model compatible to CorentinJ/Real-Time-Voice-Cloning",
+    "run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ",
+    "audio":{
         // Audio processing parameters
-        "num_mels": 40, // size of the mel spec frame.
-        "num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
-        "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
-        "frame_length_ms": 50, // stft window length in ms.
-        "frame_shift_ms": 12.5, // stft window hop-lengh in ms.
-        "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
-        "min_level_db": -100, // normalization range
-        "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
+        "num_mels": 40, // size of the mel spec frame.
+ "fft_size": 400, // number of stft frequency levels. Size of the linear spectogram frame. + "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. + "win_length": 400, // stft window length in ms. + "hop_length": 160, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. + "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. + "min_level_db": -100, // normalization range + "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. + "power": 1.5, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. // Normalization parameters - "signal_norm": true, // normalize the spec values in range [0, 1] + "signal_norm": true, // normalize the spec values in range [0, 1] "symmetric_norm": true, // move normalization to range [-1, 1] - "max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] - "clip_norm": true, // clip normalized values into the range. - "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! - "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! - "do_trim_silence": false // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! + "do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60 // threshold for timming silence. Set this according to your dataset. }, "reinit_layers": [], + "loss": "ge2e", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA) "grad_clip": 3.0, // upper limit for gradients for clipping. "epochs": 1000, // total number of epochs to train. "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. @@ -29,29 +36,24 @@ "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. "steps_plot_stats": 10, // number of steps to plot embeddings. "num_speakers_in_batch": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. "wd": 0.000001, // Weight decay weight. "checkpoint": true, // If true, it saves checkpoints per "save_step" "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. "print_step": 1, // Number of steps to log traning on console. - "output_path": "/media/erogol/data_ssd/Models/libri_tts/speaker_encoder/", // DATASET-RELATED: output path for all training outputs. - "num_loader_workers": 0, // number of training data loader processes. Don't set it too big. 4-8 are good values. 
+ "output_path": "../../checkpoints/voxceleb_librispeech/speaker_encoder/", // DATASET-RELATED: output path for all training outputs. "model": { "input_dim": 40, - "proj_dim": 128, - "lstm_dim": 384, - "num_lstm_layers": 3 + "proj_dim": 256, + "lstm_dim": 256, + "num_lstm_layers": 3, + "use_lstm_with_projection": false }, "datasets": [ { - "name": "libri_tts", - "path": "/home/erogol/Data/Libri-TTS/train-clean-360/", - "meta_file_train": null, - "meta_file_val": null - }, - { - "name": "libri_tts", - "path": "/home/erogol/Data/Libri-TTS/train-clean-100/", + "name": "vctk", + "path": "../../../datasets/VCTK-Corpus-removed-silence/", "meta_file_train": null, "meta_file_val": null } diff --git a/mozilla_voice_tts/speaker_encoder/dataset.py b/mozilla_voice_tts/speaker_encoder/dataset.py index 42c75dd9..d3243c13 100644 --- a/mozilla_voice_tts/speaker_encoder/dataset.py +++ b/mozilla_voice_tts/speaker_encoder/dataset.py @@ -31,7 +31,7 @@ class MyDataset(Dataset): print(f" | > Num speakers: {len(self.speakers)}") def load_wav(self, filename): - audio = self.ap.load_wav(filename) + audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) return audio def load_data(self, idx): diff --git a/mozilla_voice_tts/speaker_encoder/generic_utils.py b/mozilla_voice_tts/speaker_encoder/generic_utils.py index f649ceb9..bc72c91c 100644 --- a/mozilla_voice_tts/speaker_encoder/generic_utils.py +++ b/mozilla_voice_tts/speaker_encoder/generic_utils.py @@ -15,7 +15,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, 'optimizer': optimizer.state_dict() if optimizer is not None else None, 'step': current_step, 'epoch': epoch, - 'GE2Eloss': model_loss, + 'loss': model_loss, 'date': datetime.date.today().strftime("%B %d, %Y"), } torch.save(state, checkpoint_path) @@ -29,7 +29,7 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, 'model': new_state_dict, 'optimizer': optimizer.state_dict(), 'step': current_step, - 'GE2Eloss': model_loss, + 'loss': model_loss, 'date': datetime.date.today().strftime("%B %d, %Y"), } best_loss = model_loss diff --git a/mozilla_voice_tts/speaker_encoder/loss.py b/mozilla_voice_tts/speaker_encoder/losses.py similarity index 72% rename from mozilla_voice_tts/speaker_encoder/loss.py rename to mozilla_voice_tts/speaker_encoder/losses.py index ab290547..35ff73fa 100644 --- a/mozilla_voice_tts/speaker_encoder/loss.py +++ b/mozilla_voice_tts/speaker_encoder/losses.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F - +import numpy as np # adapted from https://github.com/cvqluu/GE2E-Loss class GE2ELoss(nn.Module): @@ -23,6 +23,8 @@ class GE2ELoss(nn.Module): self.b = nn.Parameter(torch.tensor(init_b)) self.loss_method = loss_method + print(' > Initialised Generalized End-to-End loss') + assert self.loss_method in ["softmax", "contrast"] if self.loss_method == "softmax": @@ -119,3 +121,40 @@ class GE2ELoss(nn.Module): cos_sim_matrix = self.w * cos_sim_matrix + self.b L = self.embed_loss(dvecs, cos_sim_matrix) return L.mean() + +# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py +class AngleProtoLoss(nn.Module): + """ + Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982 + Accepts an input of size (N, M, D) + where N is the number of speakers in the batch, + M is the number of utterances per speaker, + and D is the dimensionality of the embedding vector + Args: + - init_w (float): defines the initial value of w + - init_b (float): definies the initial 
value of b + """ + def __init__(self, init_w=10.0, init_b=-5.0): + super(AngleProtoLoss, self).__init__() + # pylint: disable=E1102 + self.w = nn.Parameter(torch.tensor(init_w)) + # pylint: disable=E1102 + self.b = nn.Parameter(torch.tensor(init_b)) + self.criterion = torch.nn.CrossEntropyLoss() + + print(' > Initialised Angular Prototypical loss') + + def forward(self, x): + """ + Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + out_anchor = torch.mean(x[:, 1:, :], 1) + out_positive = x[:, 0, :] + num_speakers = out_anchor.size()[0] + + cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2)) + torch.clamp(self.w, 1e-6) + cos_sim_matrix = cos_sim_matrix * self.w + self.b + label = torch.from_numpy(np.asarray(range(0, num_speakers))).to(cos_sim_matrix.device) + L = self.criterion(cos_sim_matrix, label) + return L diff --git a/mozilla_voice_tts/speaker_encoder/model.py b/mozilla_voice_tts/speaker_encoder/model.py index ca2abe31..df0527bc 100644 --- a/mozilla_voice_tts/speaker_encoder/model.py +++ b/mozilla_voice_tts/speaker_encoder/model.py @@ -16,15 +16,33 @@ class LSTMWithProjection(nn.Module): o, (_, _) = self.lstm(x) return self.linear(o) +class LSTMWithoutProjection(nn.Module): + def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): + super().__init__() + self.lstm = nn.LSTM(input_size=input_dim, + hidden_size=lstm_dim, + num_layers=num_lstm_layers, + batch_first=True) + self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) + self.relu = nn.ReLU() + def forward(self, x): + _, (hidden, _) = self.lstm(x) + return self.relu(self.linear(hidden[-1])) class SpeakerEncoder(nn.Module): - def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3): + def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True): super().__init__() + self.use_lstm_with_projection = use_lstm_with_projection layers = [] - layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim)) - for _ in range(num_lstm_layers - 1): - layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim)) - self.layers = nn.Sequential(*layers) + # choise LSTM layer + if use_lstm_with_projection: + layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim)) + for _ in range(num_lstm_layers - 1): + layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim)) + self.layers = nn.Sequential(*layers) + else: + self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) + self._init_layers() def _init_layers(self): @@ -37,12 +55,18 @@ class SpeakerEncoder(nn.Module): def forward(self, x): # TODO: implement state passing for lstms d = self.layers(x) - d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) + if self.use_lstm_with_projection: + d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) + else: + d = torch.nn.functional.normalize(d, p=2, dim=1) return d def inference(self, x): d = self.layers.forward(x) - d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) + if self.use_lstm_with_projection: + d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) + else: + d = torch.nn.functional.normalize(d, p=2, dim=1) return d def compute_embedding(self, x, num_frames=160, overlap=0.5): diff --git a/mozilla_voice_tts/tts/configs/config.json b/mozilla_voice_tts/tts/configs/config.json index cd4595b9..2a61ba03 100644 --- 
a/mozilla_voice_tts/tts/configs/config.json +++ b/mozilla_voice_tts/tts/configs/config.json @@ -123,28 +123,37 @@ "max_seq_len": 153, // DATASET-RELATED: maximum text length // PATHS - "output_path": "/home/erogol/Models/LJSpeech/", + "output_path": "../../Mozilla-TTS/vctk-test/", // PHONEMES - "phoneme_cache_path": "/media/erogol/data_ssd2/mozilla_us_phonemes_3", // phoneme computation is slow, therefore, it caches results in the given folder. + "phoneme_cache_path": "../../Mozilla-TTS/vctk-test/", // phoneme computation is slow, therefore, it caches results in the given folder. "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages // MULTI-SPEAKER and GST - "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning. - "style_wav_for_test": null, // path to style wav file to be used in TacotronGST inference. - "use_gst": false, // TACOTRON ONLY: use global style tokens + "use_speaker_embedding": true, // use speaker embedding to enable multi-speaker learning. + "use_external_speaker_embedding_file": false, // if true, forces the model to use external embedding per sample instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs /1806.04558 + "external_speaker_embedding_file": "../../speakers-vctk-en.json", // if not null and use_external_speaker_embedding_file is true, it is used to load a specific embedding file and thus uses these embeddings instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs /1806.04558 + "use_gst": true, // use global style tokens + "gst": { // gst parameter if gst is enabled + "gst_style_input": null, // Condition the style input either on a + // -> wave file [path to wave] or + // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15} + // with the dictionary being len(dict) <= len(gst_style_tokens). + "gst_embedding_dim": 512, + "gst_num_heads": 4, + "gst_style_tokens": 10 + }, // DATASETS "datasets": // List of datasets. They all merged and they get different speaker_ids. 
[ { - "name": "ljspeech", - "path": "/home/erogol/Data/LJSpeech-1.1/", - "meta_file_train": "metadata.csv", + "name": "vctk", + "path": "../../../datasets/VCTK-Corpus-removed-silence/", + "meta_file_train": ["p225", "p234", "p238", "p245", "p248", "p261", "p294", "p302", "p326", "p335", "p347"], // for vtck if list, ignore speakers id in list for train, its useful for test cloning with new speakers "meta_file_val": null } ] - } diff --git a/mozilla_voice_tts/tts/datasets/TTSDataset.py b/mozilla_voice_tts/tts/datasets/TTSDataset.py index ac524e55..1ecca75f 100644 --- a/mozilla_voice_tts/tts/datasets/TTSDataset.py +++ b/mozilla_voice_tts/tts/datasets/TTSDataset.py @@ -24,6 +24,7 @@ class MyDataset(Dataset): phoneme_cache_path=None, phoneme_language="en-us", enable_eos_bos=False, + speaker_mapping=None, verbose=False): """ Args: @@ -58,6 +59,7 @@ class MyDataset(Dataset): self.phoneme_cache_path = phoneme_cache_path self.phoneme_language = phoneme_language self.enable_eos_bos = enable_eos_bos + self.speaker_mapping = speaker_mapping self.verbose = verbose if use_phonemes and not os.path.isdir(phoneme_cache_path): os.makedirs(phoneme_cache_path, exist_ok=True) @@ -127,7 +129,8 @@ class MyDataset(Dataset): 'text': text, 'wav': wav, 'item_idx': self.items[idx][1], - 'speaker_name': speaker_name + 'speaker_name': speaker_name, + 'wav_file_name': os.path.basename(wav_file) } return sample @@ -191,9 +194,15 @@ class MyDataset(Dataset): batch[idx]['item_idx'] for idx in ids_sorted_decreasing ] text = [batch[idx]['text'] for idx in ids_sorted_decreasing] + speaker_name = [batch[idx]['speaker_name'] for idx in ids_sorted_decreasing] - + # get speaker embeddings + if self.speaker_mapping is not None: + wav_files_names = [batch[idx]['wav_file_name'] for idx in ids_sorted_decreasing] + speaker_embedding = [self.speaker_mapping[w]['embedding'] for w in wav_files_names] + else: + speaker_embedding = None # compute features mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] @@ -224,6 +233,9 @@ class MyDataset(Dataset): mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) + if speaker_embedding is not None: + speaker_embedding = torch.FloatTensor(speaker_embedding) + # compute linear spectrogram if self.compute_linear_spec: linear = [self.ap.spectrogram(w).astype('float32') for w in wav] @@ -234,7 +246,7 @@ class MyDataset(Dataset): else: linear = None return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \ - stop_targets, item_idxs + stop_targets, item_idxs, speaker_embedding raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ found {}".format(type(batch[0])))) diff --git a/mozilla_voice_tts/tts/datasets/preprocess.py b/mozilla_voice_tts/tts/datasets/preprocess.py index c3cf34e5..ece3bcb6 100644 --- a/mozilla_voice_tts/tts/datasets/preprocess.py +++ b/mozilla_voice_tts/tts/datasets/preprocess.py @@ -93,9 +93,10 @@ def mozilla_de(root_path, meta_file): def mailabs(root_path, meta_files=None): """Normalizes M-AI-Labs meta data files to TTS format""" - speaker_regex = re.compile("by_book/(male|female)/(?P[^/]+)/") + speaker_regex = re.compile( + "by_book/(male|female)/(?P[^/]+)/") if meta_files is None: - csv_files = glob(root_path+"/**/metadata.csv", recursive=True) + csv_files = glob(root_path + "/**/metadata.csv", recursive=True) else: csv_files = meta_files # meta_files = [f.strip() for f in meta_files.split(",")] @@ -115,12 +116,15 @@ def mailabs(root_path, meta_files=None): if meta_files is None: wav_file = 
os.path.join(folder, 'wavs', cols[0] + '.wav') else: - wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), 'wavs', cols[0] + '.wav') + wav_file = os.path.join(root_path, + folder.replace("metadata.csv", ""), + 'wavs', cols[0] + '.wav') if os.path.isfile(wav_file): text = cols[1].strip() items.append([text, wav_file, speaker_name]) else: - raise RuntimeError("> File %s does not exist!"%(wav_file)) + raise RuntimeError("> File %s does not exist!" % + (wav_file)) return items @@ -185,7 +189,8 @@ def libri_tts(root_path, meta_files=None): text = cols[1] items.append([text, wav_file, speaker_name]) for item in items: - assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}" + assert os.path.exists( + item[1]), f" [!] wav files don't exist - {item[1]}" return items @@ -197,7 +202,8 @@ def custom_turkish(root_path, meta_file): with open(txt_file, 'r', encoding='utf-8') as ttf: for line in ttf: cols = line.split('|') - wav_file = os.path.join(root_path, 'wavs', cols[0].strip() + '.wav') + wav_file = os.path.join(root_path, 'wavs', + cols[0].strip() + '.wav') if not os.path.exists(wav_file): skipped_files.append(wav_file) continue @@ -205,3 +211,44 @@ def custom_turkish(root_path, meta_file): items.append([text, wav_file, speaker_name]) print(f" [!] {len(skipped_files)} files skipped. They don't exist...") return items + + +# ToDo: add the dataset link when the dataset is released publicly +def brspeech(root_path, meta_file): + '''BRSpeech 3.0 beta''' + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, 'r') as ttf: + for line in ttf: + if line.startswith("wav_filename"): + continue + cols = line.split('|') + #print(cols) + wav_file = os.path.join(root_path, cols[0]) + text = cols[2] + speaker_name = cols[3] + items.append([text, wav_file, speaker_name]) + return items + + +def vctk(root_path, meta_files=None, wavs_path='wav48'): + """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" + test_speakers = meta_files + items = [] + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", + recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, + root_path).split(os.sep) + file_id = txt_file.split('.')[0] + if isinstance(test_speakers, + list): # if is list ignore this speakers ids + if speaker_id in test_speakers: + continue + with open(meta_file) as file_text: + text = file_text.readlines()[0] + wav_file = os.path.join(root_path, wavs_path, speaker_id, + file_id + '.wav') + items.append([text, wav_file, speaker_id]) + + return items \ No newline at end of file diff --git a/mozilla_voice_tts/tts/layers/gst_layers.py b/mozilla_voice_tts/tts/layers/gst_layers.py index 01f90697..a49b14a2 100644 --- a/mozilla_voice_tts/tts/layers/gst_layers.py +++ b/mozilla_voice_tts/tts/layers/gst_layers.py @@ -96,7 +96,7 @@ class StyleTokenLayer(nn.Module): self.key_dim = embedding_dim // num_heads self.style_tokens = nn.Parameter( torch.FloatTensor(num_style_tokens, self.key_dim)) - nn.init.orthogonal_(self.style_tokens) + nn.init.normal_(self.style_tokens, mean=0, std=0.5) self.attention = MultiHeadAttention( query_dim=self.query_dim, key_dim=self.key_dim, diff --git a/mozilla_voice_tts/tts/layers/tacotron.py b/mozilla_voice_tts/tts/layers/tacotron.py index 2fc9e86a..807282b3 100644 --- a/mozilla_voice_tts/tts/layers/tacotron.py +++ b/mozilla_voice_tts/tts/layers/tacotron.py @@ -291,7 +291,7 @@ class Decoder(nn.Module): def __init__(self, in_channels, frame_channels, r, memory_size, 
attn_type, attn_windowing, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, - separate_stopnet, speaker_embedding_dim): + separate_stopnet): super(Decoder, self).__init__() self.r_init = r self.r = r @@ -303,7 +303,7 @@ class Decoder(nn.Module): self.separate_stopnet = separate_stopnet self.query_dim = 256 # memory -> |Prenet| -> processed_memory - prenet_dim = frame_channels * self.memory_size + speaker_embedding_dim if self.use_memory_queue else frame_channels + speaker_embedding_dim + prenet_dim = frame_channels * self.memory_size if self.use_memory_queue else frame_channels self.prenet = Prenet( prenet_dim, prenet_type, @@ -429,7 +429,7 @@ class Decoder(nn.Module): # assert new_memory.shape[-1] == self.r * self.frame_channels self.memory_input = new_memory[:, self.frame_channels * (self.r - 1):] - def forward(self, inputs, memory, mask, speaker_embeddings=None): + def forward(self, inputs, memory, mask): """ Args: inputs: Encoder outputs. @@ -454,8 +454,7 @@ class Decoder(nn.Module): if t > 0: new_memory = memory[t - 1] self._update_memory_input(new_memory) - if speaker_embeddings is not None: - self.memory_input = torch.cat([self.memory_input, speaker_embeddings], dim=-1) + output, stop_token, attention = self.decode(inputs, mask) outputs += [output] attentions += [attention] @@ -463,15 +462,12 @@ class Decoder(nn.Module): t += 1 return self._parse_outputs(outputs, attentions, stop_tokens) - def inference(self, inputs, speaker_embeddings=None): + def inference(self, inputs): """ Args: inputs: encoder outputs. - speaker_embeddings: speaker vectors. - Shapes: - - inputs: (B, T, D_out_enc) - - speaker_embeddings: (B, D_embed) + - inputs: batch x time x encoder_out_dim """ outputs = [] attentions = [] @@ -484,8 +480,6 @@ class Decoder(nn.Module): if t > 0: new_memory = outputs[-1] self._update_memory_input(new_memory) - if speaker_embeddings is not None: - self.memory_input = torch.cat([self.memory_input, speaker_embeddings], dim=-1) output, stop_token, attention = self.decode(inputs, None) stop_token = torch.sigmoid(stop_token.data) outputs += [output] diff --git a/mozilla_voice_tts/tts/layers/tacotron2.py b/mozilla_voice_tts/tts/layers/tacotron2.py index 395a10ea..490f3728 100644 --- a/mozilla_voice_tts/tts/layers/tacotron2.py +++ b/mozilla_voice_tts/tts/layers/tacotron2.py @@ -141,14 +141,12 @@ class Decoder(nn.Module): location_attn (bool): if true, use location sensitive attention. attn_K (int): number of attention heads for GravesAttention. separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow. - speaker_embedding_dim (int): size of speaker embedding vector, for multi-speaker training. 
""" # Pylint gets confused by PyTorch conventions here #pylint: disable=attribute-defined-outside-init def __init__(self, in_channels, frame_channels, r, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, - forward_attn_mask, location_attn, attn_K, separate_stopnet, - speaker_embedding_dim): + forward_attn_mask, location_attn, attn_K, separate_stopnet): super(Decoder, self).__init__() self.frame_channels = frame_channels self.r_init = r @@ -157,7 +155,6 @@ class Decoder(nn.Module): self.separate_stopnet = separate_stopnet self.max_decoder_steps = 1000 self.stop_threshold = 0.5 - self.speaker_embedding_dim = speaker_embedding_dim # model dimensions self.query_dim = 1024 @@ -300,7 +297,7 @@ class Decoder(nn.Module): decoder_output = decoder_output[:, :self.r * self.frame_channels] return decoder_output, self.attention.attention_weights, stop_token - def forward(self, inputs, memories, mask, speaker_embeddings=None): + def forward(self, inputs, memories, mask): r"""Train Decoder with teacher forcing. Args: inputs: Encoder outputs. @@ -318,8 +315,6 @@ class Decoder(nn.Module): memories = self._reshape_memory(memories) memories = torch.cat((memory, memories), dim=0) memories = self._update_memory(memories) - if speaker_embeddings is not None: - memories = torch.cat([memories, speaker_embeddings], dim=-1) memories = self.prenet(memories) self._init_states(inputs, mask=mask) @@ -337,16 +332,14 @@ class Decoder(nn.Module): outputs, stop_tokens, alignments) return outputs, alignments, stop_tokens - def inference(self, inputs, speaker_embeddings=None): + def inference(self, inputs): r"""Decoder inference without teacher forcing and use Stopnet to stop decoder. Args: inputs: Encoder outputs. - speaker_embeddings: speaker embedding vectors. 
Shapes: - inputs: (B, T, D_out_enc) - - speaker_embeddings: (B, D_embed) - outputs: (B, T_mel, D_mel) - alignments: (B, T_in, T_out) - stop_tokens: (B, T_out) @@ -360,8 +353,6 @@ class Decoder(nn.Module): outputs, stop_tokens, alignments, t = [], [], [], 0 while True: memory = self.prenet(memory) - if speaker_embeddings is not None: - memory = torch.cat([memory, speaker_embeddings], dim=-1) decoder_output, alignment, stop_token = self.decode(memory) stop_token = torch.sigmoid(stop_token.data) outputs += [decoder_output.squeeze(1)] diff --git a/mozilla_voice_tts/tts/models/tacotron.py b/mozilla_voice_tts/tts/models/tacotron.py index 295dbeda..1dcf2fc8 100644 --- a/mozilla_voice_tts/tts/models/tacotron.py +++ b/mozilla_voice_tts/tts/models/tacotron.py @@ -28,7 +28,13 @@ class Tacotron(TacotronAbstract): bidirectional_decoder=False, double_decoder_consistency=False, ddc_r=None, + encoder_in_features=256, + decoder_in_features=256, + speaker_embedding_dim=None, gst=False, + gst_embedding_dim=256, + gst_num_heads=4, + gst_style_tokens=10, memory_size=5): super(Tacotron, self).__init__(num_chars, num_speakers, r, postnet_output_dim, @@ -37,37 +43,41 @@ class Tacotron(TacotronAbstract): forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet, bidirectional_decoder, double_decoder_consistency, - ddc_r, gst) - decoder_in_features = 512 if num_speakers > 1 else 256 - encoder_in_features = 512 if num_speakers > 1 else 256 - speaker_embedding_dim = 256 - proj_speaker_dim = 80 if num_speakers > 1 else 0 - # base model layers + ddc_r, encoder_in_features, decoder_in_features, + speaker_embedding_dim, gst, gst_embedding_dim, + gst_num_heads, gst_style_tokens) + + # speaker embedding layers + if self.num_speakers > 1: + if not self.embeddings_per_sample: + speaker_embedding_dim = 256 + self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + # speaker and gst embeddings is concat in decoder input + if self.num_speakers > 1: + self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim + + # embedding layer self.embedding = nn.Embedding(num_chars, 256, padding_idx=0) self.embedding.weight.data.normal_(0, 0.3) - self.encoder = Encoder(encoder_in_features) - self.decoder = Decoder(decoder_in_features, decoder_output_dim, r, + + # base model layers + self.encoder = Encoder(self.encoder_in_features) + self.decoder = Decoder(self.decoder_in_features, decoder_output_dim, r, memory_size, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, - attn_K, separate_stopnet, proj_speaker_dim) + attn_K, separate_stopnet) self.postnet = PostCBHG(decoder_output_dim) self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim) - # speaker embedding layers - if num_speakers > 1: - self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim) - self.speaker_embedding.weight.data.normal_(0, 0.3) - self.speaker_project_mel = nn.Sequential( - nn.Linear(speaker_embedding_dim, proj_speaker_dim), nn.Tanh()) - self.speaker_embeddings = None - self.speaker_embeddings_projected = None + # global style token layers if self.gst: - gst_embedding_dim = 256 self.gst_layer = GST(num_mel=80, - num_heads=4, - num_style_tokens=10, + num_heads=gst_num_heads, + num_style_tokens=gst_style_tokens, embedding_dim=gst_embedding_dim) # backward pass decoder if self.bidirectional_decoder: @@ -75,13 +85,12 @@ class 
Tacotron(TacotronAbstract): # setup DDC if self.double_decoder_consistency: self.coarse_decoder = Decoder( - decoder_in_features, decoder_output_dim, ddc_r, memory_size, + self.decoder_in_features, decoder_output_dim, ddc_r, memory_size, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, - attn_K, separate_stopnet, proj_speaker_dim) + attn_K, separate_stopnet) - - def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None): + def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None): """ Shapes: - characters: B x T_in @@ -89,17 +98,9 @@ class Tacotron(TacotronAbstract): - mel_specs: B x T_out x D - speaker_ids: B x 1 """ - self._init_states() input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) # B x T_in x embed_dim inputs = self.embedding(characters) - # B x speaker_embed_dim - if speaker_ids is not None: - self.compute_speaker_embedding(speaker_ids) - if self.num_speakers > 1: - # B x T_in x embed_dim + speaker_embed_dim - inputs = self._concat_speaker_embedding(inputs, - self.speaker_embeddings) # B x T_in x encoder_in_features encoder_outputs = self.encoder(inputs) # sequence masking @@ -108,15 +109,20 @@ class Tacotron(TacotronAbstract): if self.gst: # B x gst_dim encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + # speaker embedding if self.num_speakers > 1: - encoder_outputs = self._concat_speaker_embedding( - encoder_outputs, self.speaker_embeddings) + if not self.embeddings_per_sample: + # B x 1 x speaker_embed_dim + speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + else: + # B x 1 x speaker_embed_dim + speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) # decoder_outputs: B x decoder_in_features x T_out # alignments: B x T_in x encoder_in_features # stop_tokens: B x T_in decoder_outputs, alignments, stop_tokens = self.decoder( - encoder_outputs, mel_specs, input_mask, - self.speaker_embeddings_projected) + encoder_outputs, mel_specs, input_mask) # sequence masking if output_mask is not None: decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) @@ -138,22 +144,22 @@ class Tacotron(TacotronAbstract): return decoder_outputs, postnet_outputs, alignments, stop_tokens @torch.no_grad() - def inference(self, characters, speaker_ids=None, style_mel=None): + def inference(self, characters, speaker_ids=None, style_mel=None, speaker_embeddings=None): inputs = self.embedding(characters) - self._init_states() - if speaker_ids is not None: - self.compute_speaker_embedding(speaker_ids) - if self.num_speakers > 1: - inputs = self._concat_speaker_embedding(inputs, - self.speaker_embeddings) encoder_outputs = self.encoder(inputs) - if self.gst and style_mel is not None: + if self.gst: + # B x gst_dim encoder_outputs = self.compute_gst(encoder_outputs, style_mel) if self.num_speakers > 1: - encoder_outputs = self._concat_speaker_embedding( - encoder_outputs, self.speaker_embeddings) + if not self.embeddings_per_sample: + # B x 1 x speaker_embed_dim + speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + else: + # B x 1 x speaker_embed_dim + speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) decoder_outputs, alignments, stop_tokens = 
self.decoder.inference( - encoder_outputs, self.speaker_embeddings_projected) + encoder_outputs) postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = self.last_linear(postnet_outputs) decoder_outputs = decoder_outputs.transpose(1, 2) diff --git a/mozilla_voice_tts/tts/models/tacotron2.py b/mozilla_voice_tts/tts/models/tacotron2.py index 327e1bd9..a9ba442c 100644 --- a/mozilla_voice_tts/tts/models/tacotron2.py +++ b/mozilla_voice_tts/tts/models/tacotron2.py @@ -5,7 +5,6 @@ from mozilla_voice_tts.tts.layers.gst_layers import GST from mozilla_voice_tts.tts.layers.tacotron2 import Decoder, Encoder, Postnet from mozilla_voice_tts.tts.models.tacotron_abstract import TacotronAbstract - # TODO: match function arguments with tacotron class Tacotron2(TacotronAbstract): def __init__(self, @@ -28,7 +27,13 @@ class Tacotron2(TacotronAbstract): bidirectional_decoder=False, double_decoder_consistency=False, ddc_r=None, - gst=False): + encoder_in_features=512, + decoder_in_features=512, + speaker_embedding_dim=None, + gst=False, + gst_embedding_dim=512, + gst_num_heads=4, + gst_style_tokens=10): super(Tacotron2, self).__init__(num_chars, num_speakers, r, postnet_output_dim, decoder_output_dim, attn_type, attn_win, @@ -36,38 +41,48 @@ class Tacotron2(TacotronAbstract): forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet, bidirectional_decoder, double_decoder_consistency, - ddc_r, gst) - decoder_in_features = 512 if num_speakers > 1 else 512 - encoder_in_features = 512 if num_speakers > 1 else 512 - proj_speaker_dim = 80 if num_speakers > 1 else 0 - # base layers + ddc_r, encoder_in_features, decoder_in_features, + speaker_embedding_dim, gst, gst_embedding_dim, + gst_num_heads, gst_style_tokens) + + # speaker embedding layer + if self.num_speakers > 1: + if not self.embeddings_per_sample: + speaker_embedding_dim = 512 + self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + # speaker and gst embeddings is concat in decoder input + if self.num_speakers > 1: + self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim + + # embedding layer self.embedding = nn.Embedding(num_chars, 512, padding_idx=0) - if num_speakers > 1: - self.speaker_embedding = nn.Embedding(num_speakers, 512) - self.speaker_embedding.weight.data.normal_(0, 0.3) - self.encoder = Encoder(encoder_in_features) - self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win, + + # base model layers + self.encoder = Encoder(self.encoder_in_features) + self.decoder = Decoder(self.decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, - location_attn, attn_K, separate_stopnet, proj_speaker_dim) + location_attn, attn_K, separate_stopnet) self.postnet = Postnet(self.postnet_output_dim) + # global style token layers if self.gst: - gst_embedding_dim = encoder_in_features self.gst_layer = GST(num_mel=80, - num_heads=4, - num_style_tokens=10, - embedding_dim=gst_embedding_dim) + num_heads=self.gst_num_heads, + num_style_tokens=self.gst_style_tokens, + embedding_dim=self.gst_embedding_dim) # backward pass decoder if self.bidirectional_decoder: self._init_backward_decoder() # setup DDC if self.double_decoder_consistency: self.coarse_decoder = Decoder( - decoder_in_features, self.decoder_output_dim, ddc_r, attn_type, + self.decoder_in_features, self.decoder_output_dim, ddc_r, 
attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, - separate_stopnet, proj_speaker_dim) + separate_stopnet) @staticmethod def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): @@ -75,8 +90,7 @@ class Tacotron2(TacotronAbstract): mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) return mel_outputs, mel_outputs_postnet, alignments - def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None): - self._init_states() + def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None): # compute mask for padding # B x T_in_max (boolean) input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) @@ -84,20 +98,22 @@ class Tacotron2(TacotronAbstract): embedded_inputs = self.embedding(text).transpose(1, 2) # B x T_in_max x D_en encoder_outputs = self.encoder(embedded_inputs, text_lengths) - # adding speaker embeddding to encoder output - # TODO: multi-speaker - # B x speaker_embed_dim - if speaker_ids is not None: - self.compute_speaker_embedding(speaker_ids) - if self.num_speakers > 1: - # B x T_in x embed_dim + speaker_embed_dim - encoder_outputs = self._add_speaker_embedding(encoder_outputs, - self.speaker_embeddings) - encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) - # global style token + if self.gst: # B x gst_dim encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + + if self.num_speakers > 1: + if not self.embeddings_per_sample: + # B x 1 x speaker_embed_dim + speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + else: + # B x 1 x speaker_embed_dim + speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) + + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r decoder_outputs, alignments, stop_tokens = self.decoder( encoder_outputs, mel_specs, input_mask) @@ -122,14 +138,19 @@ class Tacotron2(TacotronAbstract): return decoder_outputs, postnet_outputs, alignments, stop_tokens @torch.no_grad() - def inference(self, text, speaker_ids=None): + def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None): embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) - if speaker_ids is not None: - self.compute_speaker_embedding(speaker_ids) + + if self.gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, style_mel) + if self.num_speakers > 1: - encoder_outputs = self._add_speaker_embedding(encoder_outputs, - self.speaker_embeddings) + if not self.embeddings_per_sample: + speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) + decoder_outputs, alignments, stop_tokens = self.decoder.inference( encoder_outputs) postnet_outputs = self.postnet(decoder_outputs) @@ -138,14 +159,22 @@ class Tacotron2(TacotronAbstract): decoder_outputs, postnet_outputs, alignments) return decoder_outputs, postnet_outputs, alignments, stop_tokens - def inference_truncated(self, text, speaker_ids=None): + def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None): """ Preserve model states for continuous inference """ embedded_inputs = 
self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference_truncated(embedded_inputs) - encoder_outputs = self._add_speaker_embedding(encoder_outputs, - speaker_ids) + + if self.gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, style_mel) + + if self.num_speakers > 1: + if not self.embeddings_per_sample: + speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) + mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated( encoder_outputs) mel_outputs_postnet = self.postnet(mel_outputs) @@ -153,17 +182,3 @@ class Tacotron2(TacotronAbstract): mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs( mel_outputs, mel_outputs_postnet, alignments) return mel_outputs, mel_outputs_postnet, alignments, stop_tokens - - - def _speaker_embedding_pass(self, encoder_outputs, speaker_ids): - # TODO: multi-speaker - # if hasattr(self, "speaker_embedding") and speaker_ids is None: - # raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided") - # if hasattr(self, "speaker_embedding") and speaker_ids is not None: - - # speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), - # encoder_outputs.size(1), - # -1) - # encoder_outputs = encoder_outputs + speaker_embeddings - # return encoder_outputs - pass diff --git a/mozilla_voice_tts/tts/models/tacotron_abstract.py b/mozilla_voice_tts/tts/models/tacotron_abstract.py index c9ae9b83..d98d03b7 100644 --- a/mozilla_voice_tts/tts/models/tacotron_abstract.py +++ b/mozilla_voice_tts/tts/models/tacotron_abstract.py @@ -28,7 +28,13 @@ class TacotronAbstract(ABC, nn.Module): bidirectional_decoder=False, double_decoder_consistency=False, ddc_r=None, - gst=False): + encoder_in_features=512, + decoder_in_features=512, + speaker_embedding_dim=None, + gst=False, + gst_embedding_dim=512, + gst_num_heads=4, + gst_style_tokens=10): """ Abstract Tacotron class """ super().__init__() self.num_chars = num_chars @@ -36,6 +42,9 @@ class TacotronAbstract(ABC, nn.Module): self.decoder_output_dim = decoder_output_dim self.postnet_output_dim = postnet_output_dim self.gst = gst + self.gst_embedding_dim = gst_embedding_dim + self.gst_num_heads = gst_num_heads + self.gst_style_tokens = gst_style_tokens self.num_speakers = num_speakers self.bidirectional_decoder = bidirectional_decoder self.double_decoder_consistency = double_decoder_consistency @@ -51,6 +60,9 @@ class TacotronAbstract(ABC, nn.Module): self.location_attn = location_attn self.attn_K = attn_K self.separate_stopnet = separate_stopnet + self.encoder_in_features = encoder_in_features + self.decoder_in_features = decoder_in_features + self.speaker_embedding_dim = speaker_embedding_dim # layers self.embedding = None @@ -58,8 +70,17 @@ class TacotronAbstract(ABC, nn.Module): self.decoder = None self.postnet = None + # multispeaker + if self.speaker_embedding_dim is None: + # if speaker_embedding_dim is None we need use the nn.Embedding, with default speaker_embedding_dim + self.embeddings_per_sample = False + else: + # if speaker_embedding_dim is not None we need use speaker embedding per sample + self.embeddings_per_sample = True + # global style token if self.gst: + self.decoder_in_features += gst_embedding_dim # add gst embedding dim self.gst_layer = None # model states @@ -158,11 +179,22 @@ class TacotronAbstract(ABC, nn.Module): self.speaker_embeddings_projected = self.speaker_project_mel( 
self.speaker_embeddings).squeeze(1) - def compute_gst(self, inputs, mel_specs): + def compute_gst(self, inputs, style_input): """ Compute global style token """ - # pylint: disable=not-callable - gst_outputs = self.gst_layer(mel_specs) - inputs = self._add_speaker_embedding(inputs, gst_outputs) + device = inputs.device + if isinstance(style_input, dict): + query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device) + _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) + gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) + for k_token, v_amplifier in style_input.items(): + key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) + gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) + gst_outputs = gst_outputs + gst_outputs_att * v_amplifier + elif style_input is None: + gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) + else: + gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable + inputs = self._concat_speaker_embedding(inputs, gst_outputs) return inputs @staticmethod diff --git a/mozilla_voice_tts/tts/utils/generic_utils.py b/mozilla_voice_tts/tts/utils/generic_utils.py index e98c267d..f0b718fa 100644 --- a/mozilla_voice_tts/tts/utils/generic_utils.py +++ b/mozilla_voice_tts/tts/utils/generic_utils.py @@ -44,7 +44,7 @@ def sequence_mask(sequence_length, max_len=None): return seq_range_expand < seq_length_expand -def setup_model(num_chars, num_speakers, c): +def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): print(" > Using model: {}".format(c.model)) MyModel = importlib.import_module('mozilla_voice_tts.tts.models.' + c.model.lower()) MyModel = getattr(MyModel, c.model) @@ -55,6 +55,9 @@ def setup_model(num_chars, num_speakers, c): postnet_output_dim=int(c.audio['fft_size'] / 2 + 1), decoder_output_dim=c.audio['num_mels'], gst=c.use_gst, + gst_embedding_dim=c.gst['gst_embedding_dim'], + gst_num_heads=c.gst['gst_num_heads'], + gst_style_tokens=c.gst['gst_style_tokens'], memory_size=c.memory_size, attn_type=c.attention_type, attn_win=c.windowing, @@ -69,7 +72,8 @@ def setup_model(num_chars, num_speakers, c): separate_stopnet=c.separate_stopnet, bidirectional_decoder=c.bidirectional_decoder, double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r) + ddc_r=c.ddc_r, + speaker_embedding_dim=speaker_embedding_dim) elif c.model.lower() == "tacotron2": model = MyModel(num_chars=num_chars, num_speakers=num_speakers, @@ -77,6 +81,9 @@ def setup_model(num_chars, num_speakers, c): postnet_output_dim=c.audio['num_mels'], decoder_output_dim=c.audio['num_mels'], gst=c.use_gst, + gst_embedding_dim=c.gst['gst_embedding_dim'], + gst_num_heads=c.gst['gst_num_heads'], + gst_style_tokens=c.gst['gst_style_tokens'], attn_type=c.attention_type, attn_win=c.windowing, attn_norm=c.attention_norm, @@ -90,9 +97,11 @@ def setup_model(num_chars, num_speakers, c): separate_stopnet=c.separate_stopnet, bidirectional_decoder=c.bidirectional_decoder, double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r) + ddc_r=c.ddc_r, + speaker_embedding_dim=speaker_embedding_dim) return model + class KeepAverage(): def __init__(self): self.avg_values = {} @@ -168,7 +177,7 @@ def check_config(c): check_argument('clip_norm', c['audio'], restricted=True, val_type=bool) check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000) check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0) - check_argument('spec_gain', c['audio'], 
restricted=True, val_type=float, min_val=1, max_val=100) + check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100) check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) check_argument('trim_db', c['audio'], restricted=True, val_type=int) @@ -239,15 +248,21 @@ def check_config(c): # paths check_argument('output_path', c, restricted=True, val_type=str) - # multi-speaker gst + # multi-speaker and gst check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) - check_argument('style_wav_for_test', c, restricted=True, val_type=str) + check_argument('use_external_speaker_embedding_file', c, restricted=True, val_type=bool) + check_argument('external_speaker_embedding_file', c, restricted=True, val_type=str) check_argument('use_gst', c, restricted=True, val_type=bool) + check_argument('gst', c, restricted=True, val_type=dict) + check_argument('gst_style_input', c['gst'], restricted=True, val_type=[str, dict]) + check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=0, max_val=1000) + check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=2, max_val=10) + check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1, max_val=1000) # datasets - checking only the first entry check_argument('datasets', c, restricted=True, val_type=list) for dataset_entry in c['datasets']: check_argument('name', dataset_entry, restricted=True, val_type=str) check_argument('path', dataset_entry, restricted=True, val_type=str) - check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str) + check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) diff --git a/mozilla_voice_tts/tts/utils/speakers.py b/mozilla_voice_tts/tts/utils/speakers.py index ff624b36..156e42af 100644 --- a/mozilla_voice_tts/tts/utils/speakers.py +++ b/mozilla_voice_tts/tts/utils/speakers.py @@ -10,12 +10,15 @@ def make_speakers_json_path(out_path): def load_speaker_mapping(out_path): """Loads speaker mapping if already present.""" try: - with open(make_speakers_json_path(out_path)) as f: + if os.path.splitext(out_path)[1] == '.json': + json_file = out_path + else: + json_file = make_speakers_json_path(out_path) + with open(json_file) as f: return json.load(f) except FileNotFoundError: return {} - def save_speaker_mapping(out_path, speaker_mapping): """Saves speaker mapping if not yet present.""" speakers_json_path = make_speakers_json_path(out_path) diff --git a/mozilla_voice_tts/tts/utils/synthesis.py b/mozilla_voice_tts/tts/utils/synthesis.py index fef2348d..0952c936 100644 --- a/mozilla_voice_tts/tts/utils/synthesis.py +++ b/mozilla_voice_tts/tts/utils/synthesis.py @@ -37,23 +37,25 @@ def numpy_to_tf(np_array, dtype): return tensor -def compute_style_mel(style_wav, ap): - style_mel = ap.melspectrogram( - ap.load_wav(style_wav)).expand_dims(0) +def compute_style_mel(style_wav, ap, cuda=False): + style_mel = torch.FloatTensor(ap.melspectrogram( + ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) + if cuda: + return style_mel.cuda() return style_mel -def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None): +def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None): if CONFIG.use_gst: decoder_output, postnet_output, alignments, stop_tokens = model.inference( - inputs, 
style_mel=style_mel, speaker_ids=speaker_id) + inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) else: if truncated: decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( - inputs, speaker_ids=speaker_id) + inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) else: decoder_output, postnet_output, alignments, stop_tokens = model.inference( - inputs, speaker_ids=speaker_id) + inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) return decoder_output, postnet_output, alignments, stop_tokens @@ -129,13 +131,24 @@ def inv_spectrogram(postnet_output, ap, CONFIG): return wav -def id_to_torch(speaker_id): +def id_to_torch(speaker_id, cuda=False): if speaker_id is not None: speaker_id = np.asarray(speaker_id) speaker_id = torch.from_numpy(speaker_id).unsqueeze(0) + if cuda: + return speaker_id.cuda() return speaker_id +def embedding_to_torch(speaker_embedding, cuda=False): + if speaker_embedding is not None: + speaker_embedding = np.asarray(speaker_embedding) + speaker_embedding = torch.from_numpy(speaker_embedding).unsqueeze(0).type(torch.FloatTensor) + if cuda: + return speaker_embedding.cuda() + return speaker_embedding + + # TODO: perform GL with pytorch for batching def apply_griffin_lim(inputs, input_lens, CONFIG, ap): '''Apply griffin-lim to each sample iterating throught the first dimension. @@ -165,6 +178,7 @@ def synthesis(model, enable_eos_bos_chars=False, #pylint: disable=unused-argument use_griffin_lim=False, do_trim_silence=False, + speaker_embedding=None, backend='torch'): """Synthesize voice for the given text. @@ -185,14 +199,23 @@ def synthesis(model, """ # GST processing style_mel = None - if CONFIG.model == "TacotronGST" and style_wav is not None: - style_mel = compute_style_mel(style_wav, ap) + if CONFIG.use_gst and style_wav is not None: + if isinstance(style_wav, dict): + style_mel = style_wav + else: + style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda) # preprocess the given text inputs = text_to_seqvec(text, CONFIG) # pass tensors to backend if backend == 'torch': - speaker_id = id_to_torch(speaker_id) - style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=use_cuda) + + if speaker_embedding is not None: + speaker_embedding = embedding_to_torch(speaker_embedding, cuda=use_cuda) + + if not isinstance(style_mel, dict): + style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda) inputs = inputs.unsqueeze(0) elif backend == 'tf': @@ -207,7 +230,7 @@ def synthesis(model, # synthesize voice if backend == 'torch': decoder_output, postnet_output, alignments, stop_tokens = run_model_torch( - model, inputs, CONFIG, truncated, speaker_id, style_mel) + model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding) postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch( postnet_output, decoder_output, alignments, stop_tokens) elif backend == 'tf': diff --git a/mozilla_voice_tts/tts/utils/text/cleaners.py b/mozilla_voice_tts/tts/utils/text/cleaners.py index f0a66f57..a36ebe67 100644 --- a/mozilla_voice_tts/tts/utils/text/cleaners.py +++ b/mozilla_voice_tts/tts/utils/text/cleaners.py @@ -67,15 +67,16 @@ def remove_aux_symbols(text): text = re.sub(r'[\<\>\(\)\[\]\"]+', '', text) return text - -def replace_symbols(text): +def replace_symbols(text, lang='en'): text = 
text.replace(';', ',') text = text.replace('-', ' ') - text = text.replace(':', ',') - text = text.replace('&', 'and') + text = text.replace(':', ' ') + if lang == 'en': + text = text.replace('&', 'and') + elif lang == 'pt': + text = text.replace('&', ' e ') return text - def basic_cleaners(text): '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' text = lowercase(text) @@ -91,6 +92,13 @@ def transliteration_cleaners(text): return text +def basic_german_cleaners(text): + '''Pipeline for German text''' + text = lowercase(text) + text = collapse_whitespace(text) + return text + + # TODO: elaborate it def basic_turkish_cleaners(text): '''Pipeline for Turkish text''' @@ -99,7 +107,6 @@ def basic_turkish_cleaners(text): text = lowercase(text) text = collapse_whitespace(text) return text - def english_cleaners(text): '''Pipeline for English text, including number and abbreviation expansion.''' text = convert_to_ascii(text) @@ -111,6 +118,14 @@ def english_cleaners(text): text = collapse_whitespace(text) return text +def portuguese_cleaners(text): + '''Basic pipeline for Portuguese text. There is no need to expand abbreviations and + numbers; the phonemizer already does that.''' + text = lowercase(text) + text = replace_symbols(text, lang='pt') + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text def phoneme_cleaners(text): '''Pipeline for phonemes mode, including number and abbreviation expansion.''' diff --git a/mozilla_voice_tts/utils/generic_utils.py b/mozilla_voice_tts/utils/generic_utils.py index 478b4358..dcfbbdc3 100644 --- a/mozilla_voice_tts/utils/generic_utils.py +++ b/mozilla_voice_tts/utils/generic_utils.py @@ -146,5 +146,11 @@ def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restrict assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}' if enum_list: assert c[name].lower() in enum_list, f' [!] {name} is not a valid value' - if val_type: + if isinstance(val_type, list): + is_valid = False + for typ in val_type: + if isinstance(c[name], typ): + is_valid = True + assert is_valid or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' + elif val_type: assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' diff --git a/notebooks/AngleProto-Speaker_Encoder- ExtractSpeakerEmbeddings-by-sample.ipynb b/notebooks/AngleProto-Speaker_Encoder- ExtractSpeakerEmbeddings-by-sample.ipynb new file mode 100644 index 00000000..15206130 --- /dev/null +++ b/notebooks/AngleProto-Speaker_Encoder- ExtractSpeakerEmbeddings-by-sample.ipynb @@ -0,0 +1,163 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is a notebook used to generate the speaker embeddings with the AngleProto speaker encoder model for multi-speaker training.\n", + "\n", + "Before running this script, please DON'T FORGET:\n", + "- to set file paths.\n", + "- to download the related model files from TTS.\n", + "- to download or clone the related repos, linked below.\n", + "- to set up the repositories: 
```python setup.py install```\n", + "- to check out the right commit versions of TTS (given next to the model).\n", + "- to set the right paths in the cell below.\n", + "\n", + "Repository:\n", + "- TTS: https://github.com/mozilla/TTS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import os\n", + "import importlib\n", + "import random\n", + "import librosa\n", + "import torch\n", + "\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "from TTS.tts.utils.speakers import save_speaker_mapping, load_speaker_mapping\n", + "\n", + "# you may need to change this depending on your system\n", + "os.environ['CUDA_VISIBLE_DEVICES']='0'\n", + "\n", + "\n", + "from TTS.speaker_encoder.model import SpeakerEncoder\n", + "from TTS.utils.audio import AudioProcessor\n", + "from TTS.utils.io import load_config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You should also adjust all the path constants below to point at the relevant locations on your machine." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_RUN_PATH = \"../../Mozilla-TTS/checkpoints/libritts_100+360-angleproto-June-06-2020_04+12PM-9c04d1f/\"\n", + "MODEL_PATH = MODEL_RUN_PATH + \"best_model.pth.tar\"\n", + "CONFIG_PATH = MODEL_RUN_PATH + \"config.json\"\n", + "\n", + "\n", + "DATASETS_NAME = ['vctk'] # list the datasets\n", + "DATASETS_PATH = ['../../../datasets/VCTK/']\n", + "DATASETS_METAFILE = ['']\n", + "\n", + "USE_CUDA = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Preprocess the dataset\n", + "meta_data = []\n", + "for i in range(len(DATASETS_NAME)):\n", + "    preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n", + "    preprocessor = getattr(preprocessor, DATASETS_NAME[i].lower())\n", + "    meta_data += preprocessor(DATASETS_PATH[i],DATASETS_METAFILE[i])\n", + "    \n", + "meta_data= list(meta_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "c = load_config(CONFIG_PATH)\n", + "ap = AudioProcessor(**c['audio'])\n", + "\n", + "model = SpeakerEncoder(**c.model)\n", + "model.load_state_dict(torch.load(MODEL_PATH)['model'])\n", + "model.eval()\n", + "if USE_CUDA:\n", + "    model.cuda()\n", + "\n", + "embeddings_dict = {}\n", + "len_meta_data= len(meta_data)\n", + "\n", + "for i in tqdm(range(len_meta_data)):\n", + "    _, wav_file, speaker_id = meta_data[i]\n", + "    wav_file_name = os.path.basename(wav_file)\n", + "    mel_spec = ap.melspectrogram(ap.load_wav(wav_file)).T\n", + "    mel_spec = torch.FloatTensor(mel_spec[None, :, :])\n", + "    if USE_CUDA:\n", + "        mel_spec = mel_spec.cuda()\n", + "    embedd = model.compute_embedding(mel_spec).cpu().detach().numpy().reshape(-1)\n", + "    embeddings_dict[wav_file_name] = [embedd,speaker_id]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create and export speakers.json\n", + "speaker_mapping = {sample: {'name': embeddings_dict[sample][1], 'embedding':embeddings_dict[sample][0].reshape(-1).tolist()} for i, sample in enumerate(embeddings_dict.keys())}\n", + "save_speaker_mapping(MODEL_RUN_PATH, speaker_mapping)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#test load 
integrity\n", + "speaker_mapping_load = load_speaker_mapping(MODEL_RUN_PATH)\n", + "assert speaker_mapping == speaker_mapping_load\n", + "print(\"The file speakers.json has been exported to \",MODEL_RUN_PATH, ' with ', len(embeddings_dict.keys()), ' speakers')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/Demo_Mozilla_TTS_MultiSpeaker_jia_et_al_2018.ipynb b/notebooks/Demo_Mozilla_TTS_MultiSpeaker_jia_et_al_2018.ipynb new file mode 100755 index 00000000..458422c0 --- /dev/null +++ b/notebooks/Demo_Mozilla_TTS_MultiSpeaker_jia_et_al_2018.ipynb @@ -0,0 +1,637 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Demo-Mozilla-TTS-MultiSpeaker-jia-et-al-2018.ipynb", + "provenance": [], + "collapsed_sections": [ + "vnV-FigfvsS2", + "hkvv7gRcx4WV", + "QJ6VgT2a4vHW" + ] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "yZK6UdwSFnOO", + "colab_type": "text" + }, + "source": [ + "# **Download and install Mozilla TTS**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "yvb0pX3WY6MN", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import os \n", + "!git clone https://github.com/Edresson/TTS -b dev-gst-embeddings" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "iB9nl2UEG3SY", + "colab_type": "code", + "colab": {} + }, + "source": [ + "!apt-get install espeak\n", + "os.chdir('TTS')\n", + "!pip install -r requirements.txt\n", + "!python setup.py develop\n", + "os.chdir('..')" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "w6Krn8k1inC_", + "colab_type": "text" + }, + "source": [ + "\n", + "\n", + "**Download Checkpoint**\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "PiYHf3lKhi9z", + "colab_type": "code", + "colab": {} + }, + "source": [ + "!wget -c -q --show-progress -O ./TTS-checkpoint.zip https://github.com/Edresson/TTS/releases/download/v1.0.0/Checkpoints-TTS-MultiSpeaker-Jia-et-al-2018.zip\n", + "!unzip ./TTS-checkpoint.zip\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MpYNgqrZcJKn", + "colab_type": "text" + }, + "source": [ + "**Utils Functions**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4KZA4b_CbMqx", + "colab_type": "code", + "colab": {} + }, + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import argparse\n", + "import json\n", + "# pylint: disable=redefined-outer-name, unused-argument\n", + "import os\n", + "import string\n", + "import time\n", + "import sys\n", + "import numpy as np\n", + "\n", + "TTS_PATH = \"../content/TTS\"\n", + "# add libraries into environment\n", + "sys.path.append(TTS_PATH) # set this if TTS is not installed globally\n", + "\n", + "import torch\n", + "\n", + "from TTS.tts.utils.generic_utils import setup_model\n", + "from TTS.tts.utils.synthesis import synthesis\n", + "from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols\n", + 
"from TTS.utils.audio import AudioProcessor\n", + "from TTS.utils.io import load_config\n", + "from TTS.vocoder.utils.generic_utils import setup_generator\n", + "\n", + "\n", + "def tts(model, vocoder_model, text, CONFIG, use_cuda, ap, use_gl, speaker_fileid, speaker_embedding=None):\n", + " t_1 = time.time()\n", + " waveform, _, _, mel_postnet_spec, _, _ = synthesis(model, text, CONFIG, use_cuda, ap, speaker_fileid, None, False, CONFIG.enable_eos_bos_chars, use_gl, speaker_embedding=speaker_embedding)\n", + " if CONFIG.model == \"Tacotron\" and not use_gl:\n", + " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", + " if not use_gl:\n", + " waveform = vocoder_model.inference(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0))\n", + " if use_cuda and not use_gl:\n", + " waveform = waveform.cpu()\n", + " if not use_gl:\n", + " waveform = waveform.numpy()\n", + " waveform = waveform.squeeze()\n", + " rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n", + " tps = (time.time() - t_1) / len(waveform)\n", + " print(\" > Run-time: {}\".format(time.time() - t_1))\n", + " print(\" > Real-time factor: {}\".format(rtf))\n", + " print(\" > Time per step: {}\".format(tps))\n", + " return waveform\n", + "\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ENA2OumIVeMA", + "colab_type": "text" + }, + "source": [ + "# **Vars definitions**\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "jPD0d_XpVXmY", + "colab_type": "code", + "colab": {} + }, + "source": [ + "TEXT = ''\n", + "OUT_PATH = 'tests-audios/'\n", + "# create output path\n", + "os.makedirs(OUT_PATH, exist_ok=True)\n", + "\n", + "SPEAKER_FILEID = None # if None use the first embedding from speakers.json\n", + "\n", + "# model vars \n", + "MODEL_PATH = 'best_model.pth.tar'\n", + "CONFIG_PATH = 'config.json'\n", + "SPEAKER_JSON = 'speakers.json'\n", + "\n", + "# vocoder vars\n", + "VOCODER_PATH = ''\n", + "VOCODER_CONFIG_PATH = ''\n", + "\n", + "USE_CUDA = True" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dV6cXXlfi72r", + "colab_type": "text" + }, + "source": [ + "# **Restore TTS Model**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "x1WgLFauWUPe", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# load the config\n", + "C = load_config(CONFIG_PATH)\n", + "C.forward_attn_mask = True\n", + "\n", + "# load the audio processor\n", + "ap = AudioProcessor(**C.audio)\n", + "\n", + "# if the vocabulary was passed, replace the default\n", + "if 'characters' in C.keys():\n", + " symbols, phonemes = make_symbols(**C.characters)\n", + "\n", + "speaker_embedding = None\n", + "speaker_embedding_dim = None\n", + "num_speakers = 0\n", + "# load speakers\n", + "if SPEAKER_JSON != '':\n", + " speaker_mapping = json.load(open(SPEAKER_JSON, 'r'))\n", + " num_speakers = len(speaker_mapping)\n", + " if C.use_external_speaker_embedding_file:\n", + " if SPEAKER_FILEID is not None:\n", + " speaker_embedding = speaker_mapping[SPEAKER_FILEID]['embedding']\n", + " else: # if speaker_fileid is not specificated use the first sample in speakers.json\n", + " choise_speaker = list(speaker_mapping.keys())[0]\n", + " print(\" Speaker: \",choise_speaker.split('_')[0],'was chosen automatically', \"(this speaker seen in training)\")\n", + " speaker_embedding = speaker_mapping[choise_speaker]['embedding']\n", + " speaker_embedding_dim = len(speaker_embedding)\n", + "\n", + "# load the model\n", 
+ "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n", + "model = setup_model(num_chars, num_speakers, C, speaker_embedding_dim)\n", + "cp = torch.load(MODEL_PATH, map_location=torch.device('cpu'))\n", + "model.load_state_dict(cp['model'])\n", + "model.eval()\n", + "\n", + "if USE_CUDA:\n", + " model.cuda()\n", + "\n", + "model.decoder.set_r(cp['r'])\n", + "\n", + "# load vocoder model\n", + "if VOCODER_PATH!= \"\":\n", + " VC = load_config(VOCODER_CONFIG_PATH)\n", + " vocoder_model = setup_generator(VC)\n", + " vocoder_model.load_state_dict(torch.load(VOCODER_PATH, map_location=\"cpu\")[\"model\"])\n", + " vocoder_model.remove_weight_norm()\n", + " if USE_CUDA:\n", + " vocoder_model.cuda()\n", + " vocoder_model.eval()\n", + "else:\n", + " vocoder_model = None\n", + " VC = None\n", + "\n", + "# synthesize voice\n", + "use_griffin_lim = VOCODER_PATH== \"\"\n", + "\n", + "if not C.use_external_speaker_embedding_file:\n", + " if SPEAKER_FILEID.isdigit():\n", + " SPEAKER_FILEID = int(SPEAKER_FILEID)\n", + " else:\n", + " SPEAKER_FILEID = None\n", + "else:\n", + " SPEAKER_FILEID = None\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tNvVEoE30qY6", + "colab_type": "text" + }, + "source": [ + "Synthesize sentence with Speaker\n", + "\n", + "> Stop running the cell to leave!\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "2o8fXkVSyXOa", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import IPython\n", + "from IPython.display import Audio\n", + "print(\"Synthesize sentence with Speaker: \",choise_speaker.split('_')[0], \"(this speaker seen in training)\")\n", + "while True:\n", + " TEXT = input(\"Enter sentence: \")\n", + " print(\" > Text: {}\".format(TEXT))\n", + " wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding)\n", + " IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + " # save the results\n", + " file_name = TEXT.replace(\" \", \"_\")\n", + " file_name = file_name.translate(\n", + " str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + " out_path = os.path.join(OUT_PATH, file_name)\n", + " print(\" > Saving output to {}\".format(out_path))\n", + " ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vnV-FigfvsS2", + "colab_type": "text" + }, + "source": [ + "# **Select Speaker**\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "RuCGOnJ_fgDV", + "colab_type": "code", + "colab": {} + }, + "source": [ + "\n", + "# VCTK speakers not seen in training (new speakers)\n", + "VCTK_test_Speakers = [\"p225\", \"p234\", \"p238\", \"p245\", \"p248\", \"p261\", \"p294\", \"p302\", \"p326\", \"p335\", \"p347\"]\n", + "\n", + "# VCTK speakers seen in training\n", + "VCTK_train_Speakers = ['p244', 'p300', 'p303', 'p273', 'p292', 'p252', 'p254', 'p269', 'p345', 'p274', 'p363', 'p285', 'p351', 'p361', 'p295', 'p266', 'p307', 'p230', 'p339', 'p253', 'p310', 'p241', 'p256', 'p323', 'p237', 'p229', 'p298', 'p336', 'p276', 'p305', 'p255', 'p278', 'p299', 'p265', 'p267', 'p280', 'p260', 'p272', 'p262', 'p334', 'p283', 'p247', 'p246', 'p374', 'p297', 'p249', 'p250', 'p304', 'p240', 'p236', 'p312', 'p286', 'p263', 'p258', 'p313', 'p376', 'p279', 'p340', 'p362', 'p284', 'p231', 'p308', 'p277', 'p275', 'p333', 'p314', 'p330', 'p264', 'p226', 'p288', 'p343', 'p239', 'p232', 'p268', 'p270', 'p329', 
'p227', 'p271', 'p228', 'p311', 'p301', 'p293', 'p364', 'p251', 'p317', 'p360', 'p281', 'p243', 'p287', 'p233', 'p259', 'p316', 'p257', 'p282', 'p306', 'p341', 'p318']\n", + "\n", + "\n", + "num_samples_speaker = 2 # In theory the more samples of the speaker the more similar to the real voice it will be!\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hkvv7gRcx4WV", + "colab_type": "text" + }, + "source": [ + "## **Example select a VCTK seen speaker in training**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "BviNMI9UyCYz", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# get embedding\n", + "Speaker_choise = VCTK_train_Speakers[0] # choise one of training speakers\n", + "# load speakers\n", + "if SPEAKER_JSON != '':\n", + " speaker_mapping = json.load(open(SPEAKER_JSON, 'r'))\n", + " if C.use_external_speaker_embedding_file:\n", + " speaker_embeddings = []\n", + " for key in list(speaker_mapping.keys()):\n", + " if Speaker_choise in key:\n", + " if len(speaker_embeddings) < num_samples_speaker:\n", + " speaker_embeddings.append(speaker_mapping[key]['embedding'])\n", + " # takes the average of the embedings samples of the announcers\n", + " speaker_embedding = np.mean(np.array(speaker_embeddings), axis=0).tolist()\n", + " " + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "5e5_XnLsx3jg", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import IPython\n", + "from IPython.display import Audio\n", + "print(\"Synthesize sentence with Speaker: \",Speaker_choise.split('_')[0], \"(this speaker seen in training)\")\n", + "while True:\n", + " TEXT = input(\"Enter sentence: \")\n", + " print(\" > Text: {}\".format(TEXT))\n", + " wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding)\n", + " IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + " # save the results\n", + " file_name = TEXT.replace(\" \", \"_\")\n", + " file_name = file_name.translate(\n", + " str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + " out_path = os.path.join(OUT_PATH, file_name)\n", + " print(\" > Saving output to {}\".format(out_path))\n", + " ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QJ6VgT2a4vHW" + }, + "source": [ + "## **Example select a VCTK not seen speaker in training (new Speakers)**\n", + "\n", + "\n", + "> Fitting new Speakers :)\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "SZS57ZK-4vHa", + "colab": {} + }, + "source": [ + "# get embedding\n", + "Speaker_choise = VCTK_test_Speakers[0] # choise one of training speakers\n", + "# load speakers\n", + "if SPEAKER_JSON != '':\n", + " speaker_mapping = json.load(open(SPEAKER_JSON, 'r'))\n", + " if C.use_external_speaker_embedding_file:\n", + " speaker_embeddings = []\n", + " for key in list(speaker_mapping.keys()):\n", + " if Speaker_choise in key:\n", + " if len(speaker_embeddings) < num_samples_speaker:\n", + " speaker_embeddings.append(speaker_mapping[key]['embedding'])\n", + " # takes the average of the embedings samples of the announcers\n", + " speaker_embedding = np.mean(np.array(speaker_embeddings), axis=0).tolist()\n", + " " + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": 
"code", + "id": "bbs85vzz4vHo", + "colab": {} + }, + "source": [ + "import IPython\n", + "from IPython.display import Audio\n", + "print(\"Synthesize sentence with Speaker: \",Speaker_choise.split('_')[0], \"(this speaker not seen in training (new speaker))\")\n", + "while True:\n", + " TEXT = input(\"Enter sentence: \")\n", + " print(\" > Text: {}\".format(TEXT))\n", + " wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding)\n", + " IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + " # save the results\n", + " file_name = TEXT.replace(\" \", \"_\")\n", + " file_name = file_name.translate(\n", + " str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + " out_path = os.path.join(OUT_PATH, file_name)\n", + " print(\" > Saving output to {}\".format(out_path))\n", + " ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LEE6mQLh5Who" + }, + "source": [ + "# **Example Synthesizing with your own voice :)**\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "La70gSB65nrs", + "colab_type": "text" + }, + "source": [ + " Download and load GE2E Speaker Encoder " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "r0IEFZ0B5vQg", + "colab_type": "code", + "colab": {} + }, + "source": [ + "!wget -c -q --show-progress -O ./SpeakerEncoder-checkpoint.zip https://github.com/Edresson/TTS/releases/download/v1.0.0/GE2E-SpeakerEncoder-iter25k.zip\n", + "!unzip ./SpeakerEncoder-checkpoint.zip" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "jEH8HCTh5mF6", + "colab_type": "code", + "colab": {} + }, + "source": [ + "SE_MODEL_RUN_PATH = \"GE2E-SpeakerEncoder/\"\n", + "SE_MODEL_PATH = os.path.join(SE_MODEL_RUN_PATH, \"best_model.pth.tar\")\n", + "SE_CONFIG_PATH =os.path.join(SE_MODEL_RUN_PATH, \"config.json\")\n", + "USE_CUDA = True" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "tOwkfQqT6-Qo", + "colab_type": "code", + "colab": {} + }, + "source": [ + "from TTS.utils.audio import AudioProcessor\n", + "from TTS.speaker_encoder.model import SpeakerEncoder\n", + "se_config = load_config(SE_CONFIG_PATH)\n", + "se_ap = AudioProcessor(**se_config['audio'])\n", + "\n", + "se_model = SpeakerEncoder(**se_config.model)\n", + "se_model.load_state_dict(torch.load(SE_MODEL_PATH)['model'])\n", + "se_model.eval()\n", + "if USE_CUDA:\n", + " se_model.cuda()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0TLlbUFG8O36", + "colab_type": "text" + }, + "source": [ + "Upload a wav audio file in your voice.\n", + "\n", + "\n", + "> We recommend files longer than 3 seconds, the bigger the file the closer to your voice :)\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_FWwHPjJ8NXl", + "colab_type": "code", + "colab": {} + }, + "source": [ + "from google.colab import files\n", + "file_list = files.upload()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "WWOf6sgbBbGY", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# extract embedding from wav files\n", + "speaker_embeddings = []\n", + "for name in file_list.keys():\n", + " if '.wav' in name:\n", + " mel_spec = se_ap.melspectrogram(se_ap.load_wav(name, sr=se_ap.sample_rate)).T\n", + " 
mel_spec = torch.FloatTensor(mel_spec[None, :, :])\n", + " if USE_CUDA:\n", + " mel_spec = mel_spec.cuda()\n", + " embedd = se_model.compute_embedding(mel_spec).cpu().detach().numpy().reshape(-1)\n", + " speaker_embeddings.append(embedd)\n", + " else:\n", + " print(\" You need upload Wav files, others files is not supported !!\")\n", + "\n", + "# takes the average of the embedings samples of the announcers\n", + "speaker_embedding = np.mean(np.array(speaker_embeddings), axis=0).tolist()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "xmItcGac5WiG", + "colab": {} + }, + "source": [ + "import IPython\n", + "from IPython.display import Audio\n", + "print(\"Synthesize sentence with New Speaker using files: \",file_list.keys(), \"(this speaker not seen in training (new speaker))\")\n", + "while True:\n", + " TEXT = input(\"Enter sentence: \")\n", + " print(\" > Text: {}\".format(TEXT))\n", + " wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding)\n", + " IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + " # save the results\n", + " file_name = TEXT.replace(\" \", \"_\")\n", + " file_name = file_name.translate(\n", + " str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + " out_path = os.path.join(OUT_PATH, file_name)\n", + " print(\" > Saving output to {}\".format(out_path))\n", + " ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/notebooks/Demo_Mozilla_TTS_MultiSpeaker_jia_et_al_2018_With_GST.ipynb b/notebooks/Demo_Mozilla_TTS_MultiSpeaker_jia_et_al_2018_With_GST.ipynb new file mode 100755 index 00000000..e059461e --- /dev/null +++ b/notebooks/Demo_Mozilla_TTS_MultiSpeaker_jia_et_al_2018_With_GST.ipynb @@ -0,0 +1,834 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Demo-Mozilla-TTS-MultiSpeaker-jia-et-al-2018-With-GST.ipynb", + "provenance": [], + "collapsed_sections": [ + "yZK6UdwSFnOO", + "ENA2OumIVeMA", + "dV6cXXlfi72r", + "vnV-FigfvsS2", + "g_G_HweN04W-", + "LEE6mQLh5Who" + ], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "yZK6UdwSFnOO", + "colab_type": "text" + }, + "source": [ + "# **Download and install Mozilla TTS**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "yvb0pX3WY6MN", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import os \n", + "!git clone https://github.com/Edresson/TTS -b dev-gst-embeddings" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "iB9nl2UEG3SY", + "colab_type": "code", + "colab": {} + }, + "source": [ + "!apt-get install espeak\n", + "os.chdir('TTS')\n", + "!pip install -r requirements.txt\n", + "!python setup.py develop\n", + "os.chdir('..')" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "w6Krn8k1inC_", + "colab_type": "text" + }, + "source": [ + "\n", + "\n", + "**Download Checkpoint**\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "PiYHf3lKhi9z", + "colab_type": "code", + "colab": {} + }, + "source": [ + "!wget -c -q --show-progress -O ./TTS-checkpoint.zip 
https://github.com/Edresson/TTS/releases/download/v1.0.0/Checkpoints-TTS-MultiSpeaker-Jia-et-al-2018-with-GST.zip\n", + "!unzip ./TTS-checkpoint.zip\n", + "\n", + "# Download gst style example\n", + "!wget https://github.com/Edresson/TTS/releases/download/v1.0.0/gst-style-example.wav" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MpYNgqrZcJKn", + "colab_type": "text" + }, + "source": [ + "**Utils Functions**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4KZA4b_CbMqx", + "colab_type": "code", + "colab": {} + }, + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import argparse\n", + "import json\n", + "# pylint: disable=redefined-outer-name, unused-argument\n", + "import os\n", + "import string\n", + "import time\n", + "import sys\n", + "import numpy as np\n", + "\n", + "TTS_PATH = \"../content/TTS\"\n", + "# add libraries into environment\n", + "sys.path.append(TTS_PATH) # set this if TTS is not installed globally\n", + "\n", + "import torch\n", + "\n", + "from TTS.tts.utils.generic_utils import setup_model\n", + "from TTS.tts.utils.synthesis import synthesis\n", + "from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols\n", + "from TTS.utils.audio import AudioProcessor\n", + "from TTS.utils.io import load_config\n", + "from TTS.vocoder.utils.generic_utils import setup_generator\n", + "\n", + "\n", + "def tts(model, vocoder_model, text, CONFIG, use_cuda, ap, use_gl, speaker_fileid, speaker_embedding=None, gst_style=None):\n", + " t_1 = time.time()\n", + " waveform, _, _, mel_postnet_spec, _, _ = synthesis(model, text, CONFIG, use_cuda, ap, speaker_fileid, gst_style, False, CONFIG.enable_eos_bos_chars, use_gl, speaker_embedding=speaker_embedding)\n", + " if CONFIG.model == \"Tacotron\" and not use_gl:\n", + " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", + " if not use_gl:\n", + " waveform = vocoder_model.inference(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0))\n", + " if use_cuda and not use_gl:\n", + " waveform = waveform.cpu()\n", + " if not use_gl:\n", + " waveform = waveform.numpy()\n", + " waveform = waveform.squeeze()\n", + " rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n", + " tps = (time.time() - t_1) / len(waveform)\n", + " print(\" > Run-time: {}\".format(time.time() - t_1))\n", + " print(\" > Real-time factor: {}\".format(rtf))\n", + " print(\" > Time per step: {}\".format(tps))\n", + " return waveform\n", + "\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ENA2OumIVeMA", + "colab_type": "text" + }, + "source": [ + "# **Vars definitions**\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "jPD0d_XpVXmY", + "colab_type": "code", + "colab": {} + }, + "source": [ + "TEXT = ''\n", + "OUT_PATH = 'tests-audios/'\n", + "# create output path\n", + "os.makedirs(OUT_PATH, exist_ok=True)\n", + "\n", + "SPEAKER_FILEID = None # if None use the first embedding from speakers.json\n", + "\n", + "# model vars \n", + "MODEL_PATH = 'best_model.pth.tar'\n", + "CONFIG_PATH = 'config.json'\n", + "SPEAKER_JSON = 'speakers.json'\n", + "\n", + "# vocoder vars\n", + "VOCODER_PATH = ''\n", + "VOCODER_CONFIG_PATH = ''\n", + "\n", + "USE_CUDA = True" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dV6cXXlfi72r", + "colab_type": "text" + }, + "source": [ + "# **Restore TTS Model**" + ] + }, + { + "cell_type": "code", + 
"metadata": { + "id": "x1WgLFauWUPe", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# load the config\n", + "C = load_config(CONFIG_PATH)\n", + "C.forward_attn_mask = True\n", + "\n", + "# load the audio processor\n", + "ap = AudioProcessor(**C.audio)\n", + "\n", + "# if the vocabulary was passed, replace the default\n", + "if 'characters' in C.keys():\n", + " symbols, phonemes = make_symbols(**C.characters)\n", + "\n", + "speaker_embedding = None\n", + "speaker_embedding_dim = None\n", + "num_speakers = 0\n", + "# load speakers\n", + "if SPEAKER_JSON != '':\n", + " speaker_mapping = json.load(open(SPEAKER_JSON, 'r'))\n", + " num_speakers = len(speaker_mapping)\n", + " if C.use_external_speaker_embedding_file:\n", + " if SPEAKER_FILEID is not None:\n", + " speaker_embedding = speaker_mapping[SPEAKER_FILEID]['embedding']\n", + " else: # if speaker_fileid is not specificated use the first sample in speakers.json\n", + " choise_speaker = list(speaker_mapping.keys())[0]\n", + " print(\" Speaker: \",choise_speaker.split('_')[0],'was chosen automatically', \"(this speaker seen in training)\")\n", + " speaker_embedding = speaker_mapping[choise_speaker]['embedding']\n", + " speaker_embedding_dim = len(speaker_embedding)\n", + "\n", + "# load the model\n", + "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n", + "model = setup_model(num_chars, num_speakers, C, speaker_embedding_dim)\n", + "cp = torch.load(MODEL_PATH, map_location=torch.device('cpu'))\n", + "model.load_state_dict(cp['model'])\n", + "model.eval()\n", + "\n", + "if USE_CUDA:\n", + " model.cuda()\n", + "\n", + "model.decoder.set_r(cp['r'])\n", + "\n", + "# load vocoder model\n", + "if VOCODER_PATH!= \"\":\n", + " VC = load_config(VOCODER_CONFIG_PATH)\n", + " vocoder_model = setup_generator(VC)\n", + " vocoder_model.load_state_dict(torch.load(VOCODER_PATH, map_location=\"cpu\")[\"model\"])\n", + " vocoder_model.remove_weight_norm()\n", + " if USE_CUDA:\n", + " vocoder_model.cuda()\n", + " vocoder_model.eval()\n", + "else:\n", + " vocoder_model = None\n", + " VC = None\n", + "\n", + "# synthesize voice\n", + "use_griffin_lim = VOCODER_PATH== \"\"\n", + "\n", + "if not C.use_external_speaker_embedding_file:\n", + " if SPEAKER_FILEID.isdigit():\n", + " SPEAKER_FILEID = int(SPEAKER_FILEID)\n", + " else:\n", + " SPEAKER_FILEID = None\n", + "else:\n", + " SPEAKER_FILEID = None\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tNvVEoE30qY6", + "colab_type": "text" + }, + "source": [ + "Synthesize sentence with Speaker\n", + "\n", + "> Stop running the cell to leave!\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "2o8fXkVSyXOa", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import IPython\n", + "from IPython.display import Audio\n", + "print(\"Synthesize sentence with Speaker: \",choise_speaker.split('_')[0], \"(this speaker seen in training)\")\n", + "gst_style = 'gst-style-example.wav'\n", + "while True:\n", + " TEXT = input(\"Enter sentence: \")\n", + " print(\" > Text: {}\".format(TEXT))\n", + " wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding, gst_style=gst_style)\n", + " IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + " # save the results\n", + " file_name = TEXT.replace(\" \", \"_\")\n", + " file_name = file_name.translate(\n", + " str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + " out_path = 
os.path.join(OUT_PATH, file_name)\n", + " print(\" > Saving output to {}\".format(out_path))\n", + " ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vnV-FigfvsS2", + "colab_type": "text" + }, + "source": [ + "# **Select Speaker**\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "RuCGOnJ_fgDV", + "colab_type": "code", + "colab": {} + }, + "source": [ + "\n", + "# VCTK speakers not seen in training (new speakers)\n", + "VCTK_test_Speakers = [\"p225\", \"p234\", \"p238\", \"p245\", \"p248\", \"p261\", \"p294\", \"p302\", \"p326\", \"p335\", \"p347\"]\n", + "\n", + "# VCTK speakers seen in training\n", + "VCTK_train_Speakers = ['p244', 'p300', 'p303', 'p273', 'p292', 'p252', 'p254', 'p269', 'p345', 'p274', 'p363', 'p285', 'p351', 'p361', 'p295', 'p266', 'p307', 'p230', 'p339', 'p253', 'p310', 'p241', 'p256', 'p323', 'p237', 'p229', 'p298', 'p336', 'p276', 'p305', 'p255', 'p278', 'p299', 'p265', 'p267', 'p280', 'p260', 'p272', 'p262', 'p334', 'p283', 'p247', 'p246', 'p374', 'p297', 'p249', 'p250', 'p304', 'p240', 'p236', 'p312', 'p286', 'p263', 'p258', 'p313', 'p376', 'p279', 'p340', 'p362', 'p284', 'p231', 'p308', 'p277', 'p275', 'p333', 'p314', 'p330', 'p264', 'p226', 'p288', 'p343', 'p239', 'p232', 'p268', 'p270', 'p329', 'p227', 'p271', 'p228', 'p311', 'p301', 'p293', 'p364', 'p251', 'p317', 'p360', 'p281', 'p243', 'p287', 'p233', 'p259', 'p316', 'p257', 'p282', 'p306', 'p341', 'p318']\n", + "\n", + "\n", + "num_samples_speaker = 2 # In theory the more samples of the speaker the more similar to the real voice it will be!\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hkvv7gRcx4WV", + "colab_type": "text" + }, + "source": [ + "## **Example select a VCTK seen speaker in training**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "BviNMI9UyCYz", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# get embedding\n", + "Speaker_choise = VCTK_train_Speakers[0] # choise one of training speakers\n", + "# load speakers\n", + "if SPEAKER_JSON != '':\n", + " speaker_mapping = json.load(open(SPEAKER_JSON, 'r'))\n", + " if C.use_external_speaker_embedding_file:\n", + " speaker_embeddings = []\n", + " for key in list(speaker_mapping.keys()):\n", + " if Speaker_choise in key:\n", + " if len(speaker_embeddings) < num_samples_speaker:\n", + " speaker_embeddings.append(speaker_mapping[key]['embedding'])\n", + " # takes the average of the embedings samples of the announcers\n", + " speaker_embedding = np.mean(np.array(speaker_embeddings), axis=0).tolist()\n", + " " + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "5e5_XnLsx3jg", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import IPython\n", + "from IPython.display import Audio\n", + "print(\"Synthesize sentence with Speaker: \",Speaker_choise.split('_')[0], \"(this speaker seen in training)\")\n", + "gst_style = 'gst-style-example.wav'\n", + "while True:\n", + " TEXT = input(\"Enter sentence: \")\n", + " print(\" > Text: {}\".format(TEXT))\n", + " wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding, gst_style=gst_style)\n", + " IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + " # save the results\n", + " file_name = TEXT.replace(\" \", \"_\")\n", + " file_name = file_name.translate(\n", + " str.maketrans('', '', 
string.punctuation.replace('_', ''))) + '.wav'\n", + "    out_path = os.path.join(OUT_PATH, file_name)\n", + "    print(\" > Saving output to {}\".format(out_path))\n", + "    ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QJ6VgT2a4vHW" + }, + "source": [ + "## **Example: select a VCTK speaker not seen in training (new speakers)**\n", + "\n", + "\n", + "> Fitting new Speakers :)\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "SZS57ZK-4vHa", + "colab": {} + }, + "source": [ + "# get embedding\n", + "Speaker_choise = VCTK_test_Speakers[0] # choose one of the speakers not seen in training\n", + "# load speakers\n", + "if SPEAKER_JSON != '':\n", + "    speaker_mapping = json.load(open(SPEAKER_JSON, 'r'))\n", + "    if C.use_external_speaker_embedding_file:\n", + "        speaker_embeddings = []\n", + "        for key in list(speaker_mapping.keys()):\n", + "            if Speaker_choise in key:\n", + "                if len(speaker_embeddings) < num_samples_speaker:\n", + "                    speaker_embeddings.append(speaker_mapping[key]['embedding'])\n", + "        # take the average of the speaker's embedding samples\n", + "        speaker_embedding = np.mean(np.array(speaker_embeddings), axis=0).tolist()\n", + "    " + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "bbs85vzz4vHo", + "colab": {} + }, + "source": [ + "import IPython\n", + "from IPython.display import Audio\n", + "print(\"Synthesize sentence with Speaker: \",Speaker_choise.split('_')[0], \"(this speaker not seen in training (new speaker))\")\n", + "gst_style = 'gst-style-example.wav'\n", + "while True:\n", + "    TEXT = input(\"Enter sentence: \")\n", + "    print(\" > Text: {}\".format(TEXT))\n", + "    wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding, gst_style=gst_style)\n", + "    IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + "    # save the results\n", + "    file_name = TEXT.replace(\" \", \"_\")\n", + "    file_name = file_name.translate(\n", + "        str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + "    out_path = os.path.join(OUT_PATH, file_name)\n", + "    print(\" > Saving output to {}\".format(out_path))\n", + "    ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g_G_HweN04W-", + "colab_type": "text" + }, + "source": [ + "# **Changing GST tokens manually (without wav reference)**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jyFP5syW2bjt", + "colab_type": "text" + }, + "source": [ + "You can also set the GST tokens manually; this lets you strengthen or weaken the effect of a given token. For example, if a token controls the length of the speaker's pauses, increasing its value will give you longer pauses and decreasing it will give you shorter pauses."
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "SpwjDjCM2a3Y", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# set gst tokens, in this model we have 5 tokens\n", + "gst_style = {\"0\": 0, \"1\": 0, \"3\": 0, \"4\": 0}" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "qWChMbI_0z5X", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import IPython\n", + "from IPython.display import Audio\n", + "print(\"Synthesize sentence with Speaker: \",Speaker_choise.split('_')[0], \"(this speaker not seen in training (new speaker))\")\n", + "TEXT = input(\"Enter sentence: \")\n", + "print(\" > Text: {}\".format(TEXT))\n", + "wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding, gst_style=gst_style)\n", + "IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + "# save the results\n", + "file_name = TEXT.replace(\" \", \"_\")\n", + "file_name = file_name.translate(\n", + " str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + "out_path = os.path.join(OUT_PATH, file_name)\n", + "print(\" > Saving output to {}\".format(out_path))\n", + "ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "uFjUi9xQ3mG3", + "colab_type": "code", + "colab": {} + }, + "source": [ + "gst_style = {\"0\": 0.9, \"1\": 0, \"3\": 0, \"4\": 0}\n", + "print(\"Synthesize sentence with Speaker: \",Speaker_choise.split('_')[0], \"(this speaker not seen in training (new speaker))\")\n", + "TEXT = input(\"Enter sentence: \")\n", + "print(\" > Text: {}\".format(TEXT))\n", + "wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding, gst_style=gst_style)\n", + "IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + "# save the results\n", + "file_name = TEXT.replace(\" \", \"_\")\n", + "file_name = file_name.translate(\n", + " str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + "out_path = os.path.join(OUT_PATH, file_name)\n", + "print(\" > Saving output to {}\".format(out_path))\n", + "ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Uw0d6gWg4L27", + "colab_type": "code", + "colab": {} + }, + "source": [ + "gst_style = {\"0\": -0.9, \"1\": 0, \"3\": 0, \"4\": 0}\n", + "print(\"Synthesize sentence with Speaker: \",Speaker_choise.split('_')[0], \"(this speaker not seen in training (new speaker))\")\n", + "TEXT = input(\"Enter sentence: \")\n", + "print(\" > Text: {}\".format(TEXT))\n", + "wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding, gst_style=gst_style)\n", + "IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + "# save the results\n", + "file_name = TEXT.replace(\" \", \"_\")\n", + "file_name = file_name.translate(\n", + " str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + "out_path = os.path.join(OUT_PATH, file_name)\n", + "print(\" > Saving output to {}\".format(out_path))\n", + "ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "V9izw4-54-Tl", + "colab_type": "code", + "colab": {} + }, + "source": [ + "gst_style = {\"0\": 0, \"1\": 0.9, \"3\": 0, \"4\": 0}\n", + "print(\"Synthesize sentence with 
Speaker: \",Speaker_choise.split('_')[0], \"(this speaker not seen in training (new speaker))\")\n", + "TEXT = input(\"Enter sentence: \")\n", + "print(\" > Text: {}\".format(TEXT))\n", + "wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding, gst_style=gst_style)\n", + "IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + "# save the results\n", + "file_name = TEXT.replace(\" \", \"_\")\n", + "file_name = file_name.translate(\n", + " str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + "out_path = os.path.join(OUT_PATH, file_name)\n", + "print(\" > Saving output to {}\".format(out_path))\n", + "ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LEE6mQLh5Who" + }, + "source": [ + "# **Example Synthesizing with your own voice :)**\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "La70gSB65nrs", + "colab_type": "text" + }, + "source": [ + " Download and load GE2E Speaker Encoder " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "r0IEFZ0B5vQg", + "colab_type": "code", + "colab": {} + }, + "source": [ + "!wget -c -q --show-progress -O ./SpeakerEncoder-checkpoint.zip https://github.com/Edresson/TTS/releases/download/v1.0.0/GE2E-SpeakerEncoder-iter25k.zip\n", + "!unzip ./SpeakerEncoder-checkpoint.zip" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "jEH8HCTh5mF6", + "colab_type": "code", + "colab": {} + }, + "source": [ + "SE_MODEL_RUN_PATH = \"GE2E-SpeakerEncoder/\"\n", + "SE_MODEL_PATH = os.path.join(SE_MODEL_RUN_PATH, \"best_model.pth.tar\")\n", + "SE_CONFIG_PATH =os.path.join(SE_MODEL_RUN_PATH, \"config.json\")\n", + "USE_CUDA = True" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "tOwkfQqT6-Qo", + "colab_type": "code", + "colab": {} + }, + "source": [ + "from TTS.utils.audio import AudioProcessor\n", + "from TTS.speaker_encoder.model import SpeakerEncoder\n", + "se_config = load_config(SE_CONFIG_PATH)\n", + "se_ap = AudioProcessor(**se_config['audio'])\n", + "\n", + "se_model = SpeakerEncoder(**se_config.model)\n", + "se_model.load_state_dict(torch.load(SE_MODEL_PATH)['model'])\n", + "se_model.eval()\n", + "if USE_CUDA:\n", + " se_model.cuda()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0TLlbUFG8O36", + "colab_type": "text" + }, + "source": [ + "Upload one or more wav audio files in your voice.\n", + "\n", + "\n", + "> We recommend files longer than 3 seconds, the bigger the file the closer to your voice :)\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_FWwHPjJ8NXl", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# select one or more wav files\n", + "from google.colab import files\n", + "file_list = files.upload()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "WWOf6sgbBbGY", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# extract embedding from wav files\n", + "speaker_embeddings = []\n", + "for name in file_list.keys():\n", + " if '.wav' in name:\n", + " mel_spec = se_ap.melspectrogram(se_ap.load_wav(name, sr=se_ap.sample_rate)).T\n", + " mel_spec = torch.FloatTensor(mel_spec[None, :, :])\n", + " if USE_CUDA:\n", + " mel_spec = mel_spec.cuda()\n", + " embedd 
= se_model.compute_embedding(mel_spec).cpu().detach().numpy().reshape(-1)\n", + " speaker_embeddings.append(embedd)\n", + " else:\n", + " print(\"You need to upload WAV files; other file types are not supported!\")\n", + "\n", + "# take the average of the embeddings extracted from the uploaded samples\n", + "speaker_embedding = np.mean(np.array(speaker_embeddings), axis=0).tolist()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "AQ7eP31d9yzq", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import IPython\n", + "from IPython.display import Audio\n", + "print(\"Synthesize sentence with New Speaker using files: \",file_list.keys(), \"(this speaker not seen in training (new speaker))\")\n", + "gst_style = {\"0\": 0, \"1\": 0.0, \"3\": 0, \"4\": 0}\n", + "gst_style = 'gst-style-example.wav'\n", + "TEXT = input(\"Enter sentence: \")\n", + "print(\" > Text: {}\".format(TEXT))\n", + "wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding, gst_style=gst_style)\n", + "IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + "# save the results\n", + "file_name = TEXT.replace(\" \", \"_\")\n", + "file_name = file_name.translate(\n", + " str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + "out_path = os.path.join(OUT_PATH, file_name)\n", + "print(\" > Saving output to {}\".format(out_path))\n", + "ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "11i10yE1-LMJ", + "colab_type": "text" + }, + "source": [ + "Uploading your own GST reference wav file" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "eKohSQG1-KkT", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# select one wav file for GST reference\n", + "from google.colab import files\n", + "file_list = files.upload()\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "xmItcGac5WiG", + "colab": {} + }, + "source": [ + "print(\"Synthesize sentence with New Speaker using files: \",file_list.keys(), \"(this speaker not seen in training (new speaker))\")\n", + "gst_style = list(file_list.keys())[0]\n", + "TEXT = input(\"Enter sentence: \")\n", + "print(\" > Text: {}\".format(TEXT))\n", + "wav = tts(model, vocoder_model, TEXT, C, USE_CUDA, ap, use_griffin_lim, SPEAKER_FILEID, speaker_embedding=speaker_embedding, gst_style=gst_style)\n", + "IPython.display.display(Audio(wav, rate=ap.sample_rate))\n", + "# save the results\n", + "file_name = TEXT.replace(\" \", \"_\")\n", + "file_name = file_name.translate(\n", + " str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'\n", + "out_path = os.path.join(OUT_PATH, file_name)\n", + "print(\" > Saving output to {}\".format(out_path))\n", + "ap.save_wav(wav, out_path)" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/notebooks/GE2E-CorentinJ-ExtractSpeakerEmbeddings-by-sample.ipynb b/notebooks/GE2E-CorentinJ-ExtractSpeakerEmbeddings-by-sample.ipynb new file mode 100644 index 00000000..576a95fe --- /dev/null +++ b/notebooks/GE2E-CorentinJ-ExtractSpeakerEmbeddings-by-sample.ipynb @@ -0,0 +1,25495 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is a notebook used to generate the speaker embeddings with the CorentinJ GE2E model trained with Angular Prototypical loss for
multi-speaker training.\n", + "\n", + "Before running this script please DON'T FORGET:\n", + "- to set the right paths in the cell below.\n", + "\n", + "Repositories:\n", + "- TTS: https://github.com/mozilla/TTS\n", + "- CorentinJ GE2E: https://github.com/Edresson/GE2E-Speaker-Encoder" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "import os\n", + "import importlib\n", + "import random\n", + "import librosa\n", + "import torch\n", + "\n", + "import numpy as np\n", + "from TTS.utils.io import load_config\n", + "from tqdm import tqdm\n", + "from TTS.tts.utils.speakers import save_speaker_mapping, load_speaker_mapping\n", + "\n", + "# you may need to change this depending on your system\n", + "os.environ['CUDA_VISIBLE_DEVICES']='0'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'Real-Time-Voice-Cloning'...\n", + "remote: Enumerating objects: 5, done.\u001b[K\n", + "remote: Counting objects: 100% (5/5), done.\u001b[K\n", + "remote: Compressing objects: 100% (5/5), done.\u001b[K\n", + "remote: Total 2508 (delta 0), reused 3 (delta 0), pack-reused 2503\u001b[K\n", + "Receiving objects: 100% (2508/2508), 360.78 MiB | 17.84 MiB/s, done.\n", + "Resolving deltas: 100% (1387/1387), done.\n", + "Checking connectivity... done.\n" + ] + } + ], + "source": [ + "# Clone encoder \n", + "!git clone https://github.com/CorentinJ/Real-Time-Voice-Cloning.git\n", + "os.chdir('Real-Time-Voice-Cloning/')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "#Install voxceleb_trainer Requeriments\n", + "!python -m pip install umap-learn visdom webrtcvad librosa>=0.5.1 matplotlib>=2.0.2 numpy>=1.14.0 scipy>=1.0.0 tqdm sounddevice Unidecode inflect multiprocess numba" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2020-08-05 06:51:05-- https://github.com/Edresson/Real-Time-Voice-Cloning/releases/download/checkpoints/pretrained.zip\n", + "Resolving github.com (github.com)... 18.231.5.6\n", + "Connecting to github.com (github.com)|18.231.5.6|:443... connected.\n", + "HTTP request sent, awaiting response... 301 Moved Permanently\n", + "Location: https://github.com/Edresson/GE2E-Speaker-Encoder/releases/download/checkpoints/pretrained.zip [following]\n", + "--2020-08-05 06:51:05-- https://github.com/Edresson/GE2E-Speaker-Encoder/releases/download/checkpoints/pretrained.zip\n", + "Reusing existing connection to github.com:443.\n", + "HTTP request sent, awaiting response... 
302 Found\n", + "Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/263893598/f7f31d80-96df-11ea-8345-261fc35f9849?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20200805%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20200805T101614Z&X-Amz-Expires=300&X-Amz-Signature=df7724c28668ebd5dfbcc6a9b51f6afb78193c30119f3a1c3eef678188aabd1e&X-Amz-SignedHeaders=host&actor_id=0&repo_id=263893598&response-content-disposition=attachment%3B%20filename%3Dpretrained.zip&response-content-type=application%2Foctet-stream [following]\n", + "--2020-08-05 06:51:05-- https://github-production-release-asset-2e65be.s3.amazonaws.com/263893598/f7f31d80-96df-11ea-8345-261fc35f9849?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20200805%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20200805T101614Z&X-Amz-Expires=300&X-Amz-Signature=df7724c28668ebd5dfbcc6a9b51f6afb78193c30119f3a1c3eef678188aabd1e&X-Amz-SignedHeaders=host&actor_id=0&repo_id=263893598&response-content-disposition=attachment%3B%20filename%3Dpretrained.zip&response-content-type=application%2Foctet-stream\n", + "Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.216.18.24\n", + "Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.216.18.24|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 383640573 (366M) [application/octet-stream]\n", + "Saving to: ‘pretrained.zip’\n", + "\n", + "pretrained.zip 100%[===================>] 365,87M 6,62MB/s in 56s \n", + "\n", + "2020-08-05 06:52:03 (6,48 MB/s) - ‘pretrained.zip’ saved [383640573/383640573]\n", + "\n", + "Archive: pretrained.zip\n", + " creating: encoder/saved_models/\n", + " inflating: encoder/saved_models/pretrained.pt \n", + " creating: synthesizer/saved_models/\n", + " creating: synthesizer/saved_models/logs-pretrained/\n", + " creating: synthesizer/saved_models/logs-pretrained/taco_pretrained/\n", + " extracting: synthesizer/saved_models/logs-pretrained/taco_pretrained/checkpoint \n", + " inflating: synthesizer/saved_models/logs-pretrained/taco_pretrained/tacotron_model.ckpt-278000.data-00000-of-00001 \n", + " inflating: synthesizer/saved_models/logs-pretrained/taco_pretrained/tacotron_model.ckpt-278000.index \n", + " inflating: synthesizer/saved_models/logs-pretrained/taco_pretrained/tacotron_model.ckpt-278000.meta \n", + " creating: vocoder/saved_models/\n", + " creating: vocoder/saved_models/pretrained/\n", + " inflating: vocoder/saved_models/pretrained/pretrained.pt \n" + ] + } + ], + "source": [ + "#Download encoder Checkpoint\n", + "!wget https://github.com/Edresson/Real-Time-Voice-Cloning/releases/download/checkpoints/pretrained.zip\n", + "!unzip pretrained.zip" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from encoder import inference as encoder\n", + "from encoder.params_model import model_embedding_size as speaker_embedding_size\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preparing the encoder, the synthesizer and the vocoder...\n", + "Loaded encoder \"pretrained.pt\" trained to step 1564501\n", + "Testing your configuration with small inputs.\n", + "\tTesting the encoder...\n", + "(256,)\n" + ] + } + ], + "source": [ + 
"print(\"Preparing the encoder, the synthesizer and the vocoder...\")\n", + "encoder.load_model(Path('encoder/saved_models/pretrained.pt'))\n", + "print(\"Testing your configuration with small inputs.\")\n", + "# Forward an audio waveform of zeroes that lasts 1 second. Notice how we can get the encoder's\n", + "# sampling rate, which may differ.\n", + "# If you're unfamiliar with digital audio, know that it is encoded as an array of floats \n", + "# (or sometimes integers, but mostly floats in this projects) ranging from -1 to 1.\n", + "# The sampling rate is the number of values (samples) recorded per second, it is set to\n", + "# 16000 for the encoder. Creating an array of length will always correspond \n", + "# to an audio of 1 second.\n", + "print(\"\\tTesting the encoder...\")\n", + "\n", + "wav = np.zeros(encoder.sampling_rate) \n", + "embed = encoder.embed_utterance(wav)\n", + "print(embed.shape)\n", + "\n", + "# Embeddings are L2-normalized (this isn't important here, but if you want to make your own \n", + "# embeddings it will be).\n", + "#embed /= np.linalg.norm(embed) # for random embedding\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "SAVE_PATH = '../'" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# Set constants\n", + "DATASETS_NAME = ['vctk'] # list the datasets\n", + "DATASETS_PATH = ['../../../../../datasets/VCTK-Corpus-removed-silence/']\n", + "DATASETS_METAFILE = ['']\n", + "USE_CUDA = True" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " 0%| | 0/44063 [00:00 wave file [path to wave] or + // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15} + // with the dictionary being len(dict) <= len(gst_style_tokens). + "gst_embedding_dim": 512, + "gst_num_heads": 4, + "gst_style_tokens": 10 + } +} diff --git a/tests/inputs/test_train_config.json b/tests/inputs/test_train_config.json index 951fe4a3..bea4cbb7 100644 --- a/tests/inputs/test_train_config.json +++ b/tests/inputs/test_train_config.json @@ -1,3 +1,4 @@ +<<<<<<< HEAD:tests/inputs/test_train_config.json { "model": "Tacotron2", "run_name": "test_sample_dataset_run", @@ -150,3 +151,161 @@ } +======= +{ + "model": "Tacotron2", + "run_name": "ljspeech-ddc-bn", + "run_description": "tacotron2 with ddc and batch-normalization", + + // AUDIO PARAMETERS + "audio":{ + // stft parameters + "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. + "win_length": 1024, // stft window length in ms. + "hop_length": 256, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. + + // Audio processing parameters + "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. + "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. + "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. + + // Silence trimming + "do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (true), TWEB (false), Nancy (true) + "trim_db": 60, // threshold for timming silence. Set this according to your dataset. 
+ + // Griffin-Lim + "power": 1.5, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. + + // MelSpectrogram parameters + "num_mels": 80, // size of the mel spec frame. + "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! + "spec_gain": 20, + + // Normalization parameters + "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params. + "min_level_db": -100, // lower bound for normalization + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored + }, + + // VOCABULARY PARAMETERS + // if custom character set is not defined, + // default set in symbols.py is used + // "characters":{ + // "pad": "_", + // "eos": "~", + // "bos": "^", + // "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + // "punctuations":"!'(),-.:;? ", + // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" + // }, + + // DISTRIBUTED TRAINING + "distributed":{ + "backend": "nccl", + "url": "tcp:\/\/localhost:54321" + }, + + "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers. + + // TRAINING + "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "eval_batch_size":16, + "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. + "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed. + "loss_masking": true, // enable / disable loss masking against the sequence padding. + "ga_alpha": 10.0, // weight for guided attention loss. If > 0, guided attention is enabled. + + // VALIDATION + "run_eval": true, + "test_delay_epochs": 10, //Until attention is aligned, testing only wastes computation time. + "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences. + + // OPTIMIZER + "noam_schedule": false, // use noam warmup and lr schedule. + "grad_clip": 1.0, // upper limit for gradients for clipping. + "epochs": 1000, // total number of epochs to train. + "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. + "wd": 0.000001, // Weight decay weight. + "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + "seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths. 
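// Illustrative reading (editor's note, not a line of the original config): the "gradual_training"
// schedule defined above lists [first_step, r, batch_size] triples, so with the values shown it means:
//   from step 0      -> r=7, batch_size=64
//   from step 1      -> r=5, batch_size=64
//   from step 50000  -> r=3, batch_size=32
//   from step 130000 -> r=2, batch_size=32
//   from step 290000 -> r=1, batch_size=32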
+ + // TACOTRON PRENET + "memory_size": -1, // ONLY TACOTRON - size of the memory queue used fro storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame. + "prenet_type": "bn", // "original" or "bn". + "prenet_dropout": false, // enable/disable dropout at prenet. + + // TACOTRON ATTENTION + "attention_type": "original", // 'original' or 'graves' + "attention_heads": 4, // number of attention heads (only for 'graves') + "attention_norm": "sigmoid", // softmax or sigmoid. + "windowing": false, // Enables attention windowing. Used only in eval mode. + "use_forward_attn": false, // if it uses forward attention. In general, it aligns faster. + "forward_attn_mask": false, // Additional masking forcing monotonicity only in eval mode. + "transition_agent": false, // enable/disable transition agent of forward attention. + "location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default. + "bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset. + "double_decoder_consistency": true, // use DDC explained here https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency-draft/ + "ddc_r": 7, // reduction rate for coarse decoder. + + // STOPNET + "stopnet": true, // Train stopnet predicting the end of synthesis. + "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER. + + // TENSORBOARD and LOGGING + "print_step": 25, // Number of steps to log training on console. + "tb_plot_step:": 100, // Number of steps to plot TB training figures. + "print_eval": false, // If True, it prints intermediate loss values in evalulation. + "save_step": 10000, // Number of training steps expected to save traninpg stats and checkpoints. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + + // DATA LOADING + "text_cleaner": "phoneme_cleaners", + "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. + "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "num_val_loader_workers": 4, // number of evaluation data loader processes. + "batch_group_size": 0, //Number of batches to shuffle after bucketing. + "min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training + "max_seq_len": 153, // DATASET-RELATED: maximum text length + + // PATHS + "output_path": "/home/erogol/Models/LJSpeech/", + + // PHONEMES + "phoneme_cache_path": "/media/erogol/data_ssd2/mozilla_us_phonemes_3", // phoneme computation is slow, therefore, it caches results in the given folder. + "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. + "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages + + // MULTI-SPEAKER and GST + "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning. 
+ "use_gst": true, // use global style tokens + "gst": { // gst parameter if gst is enabled + "gst_style_input": null, // Condition the style input either on a + // -> wave file [path to wave] or + // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15} + // with the dictionary being len(dict) == len(gst_style_tokens). + "gst_embedding_dim": 512, + "gst_num_heads": 4, + "gst_style_tokens": 10 + }, + + // DATASETS + "datasets": // List of datasets. They all merged and they get different speaker_ids. + [ + { + "name": "ljspeech", + "path": "/home/erogol/Data/LJSpeech-1.1/", + "meta_file_train": "metadata.csv", + "meta_file_val": null + } + ] +} + +>>>>>>> Added support for Tacotron2 GST + abbility to condition style input with wav or tokens:config.json diff --git a/tests/outputs/dummy_model_config.json b/tests/outputs/dummy_model_config.json index d2e2fca0..b032f191 100644 --- a/tests/outputs/dummy_model_config.json +++ b/tests/outputs/dummy_model_config.json @@ -83,6 +83,20 @@ "use_phonemes": false, // use phonemes instead of raw characters. It is suggested for better pronounciation. "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages "text_cleaner": "phoneme_cleaners", - "use_speaker_embedding": false // whether to use additional embeddings for separate speakers + "use_speaker_embedding": false, // whether to use additional embeddings for separate speakers + + // MULTI-SPEAKER and GST + "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning. + "use_gst": true, // use global style tokens + "gst": { // gst parameter if gst is enabled + "gst_style_input": null, // Condition the style input either on a + // -> wave file [path to wave] or + // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15} + // with the dictionary being len(dict) <= len(gst_style_tokens). 
+ "gst_embedding_dim": 512, + "gst_num_heads": 4, + "gst_style_tokens": 10 + } } + diff --git a/tests/test_encoder.py b/tests/test_encoder.py index 711ad195..46266f29 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -4,7 +4,7 @@ import unittest import torch as T from tests import get_tests_input_path -from mozilla_voice_tts.speaker_encoder.loss import GE2ELoss +from mozilla_voice_tts.speaker_encoder.losses import GE2ELoss, AngleProtoLoss from mozilla_voice_tts.speaker_encoder.model import SpeakerEncoder from mozilla_voice_tts.utils.io import load_config @@ -59,6 +59,7 @@ class GE2ELossTests(unittest.TestCase): dummy_input = T.ones(4, 5, 64) # num_speaker x num_utterance x dim loss = GE2ELoss(loss_method="softmax") output = loss.forward(dummy_input) + assert output.item() >= 0.0 # check speaker loss with orthogonal d-vectors dummy_input = T.empty(3, 64) dummy_input = T.nn.init.orthogonal(dummy_input) @@ -73,6 +74,34 @@ class GE2ELossTests(unittest.TestCase): output = loss.forward(dummy_input) assert output.item() < 0.005 +class AngleProtoLossTests(unittest.TestCase): + # pylint: disable=R0201 + def test_in_out(self): + # check random input + dummy_input = T.rand(4, 5, 64) # num_speaker x num_utterance x dim + loss = AngleProtoLoss() + output = loss.forward(dummy_input) + assert output.item() >= 0.0 + + # check all zeros + dummy_input = T.ones(4, 5, 64) # num_speaker x num_utterance x dim + loss = AngleProtoLoss() + output = loss.forward(dummy_input) + assert output.item() >= 0.0 + + # check speaker loss with orthogonal d-vectors + dummy_input = T.empty(3, 64) + dummy_input = T.nn.init.orthogonal(dummy_input) + dummy_input = T.cat( + [ + dummy_input[0].repeat(5, 1, 1).transpose(0, 1), + dummy_input[1].repeat(5, 1, 1).transpose(0, 1), + dummy_input[2].repeat(5, 1, 1).transpose(0, 1), + ] + ) # num_speaker x num_utterance x dim + loss = AngleProtoLoss() + output = loss.forward(dummy_input) + assert output.item() < 0.005 # class LoaderTest(unittest.TestCase): # def test_output(self): diff --git a/tests/test_layers.py b/tests/test_layers.py index bf036f5c..0b5315c5 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -58,8 +58,7 @@ class DecoderTests(unittest.TestCase): trans_agent=True, forward_attn_mask=True, location_attn=True, - separate_stopnet=True, - speaker_embedding_dim=0) + separate_stopnet=True) dummy_input = T.rand(4, 8, 256) dummy_memory = T.rand(4, 2, 80) @@ -71,38 +70,6 @@ class DecoderTests(unittest.TestCase): assert output.shape[2] == 2, "size not {}".format(output.shape[2]) assert stop_tokens.shape[0] == 4 - @staticmethod - def test_in_out_multispeaker(): - layer = Decoder( - in_channels=256, - frame_channels=80, - r=2, - memory_size=4, - attn_windowing=False, - attn_norm="sigmoid", - attn_K=5, - attn_type="graves", - prenet_type='original', - prenet_dropout=True, - forward_attn=True, - trans_agent=True, - forward_attn_mask=True, - location_attn=True, - separate_stopnet=True, - speaker_embedding_dim=80) - dummy_input = T.rand(4, 8, 256) - dummy_memory = T.rand(4, 2, 80) - dummy_embed = T.rand(4, 80) - - output, alignment, stop_tokens = layer( - dummy_input, dummy_memory, mask=None, speaker_embeddings=dummy_embed) - - assert output.shape[0] == 4 - assert output.shape[1] == 80, "size not {}".format(output.shape[1]) - assert output.shape[2] == 2, "size not {}".format(output.shape[2]) - assert stop_tokens.shape[0] == 4 - - class EncoderTests(unittest.TestCase): def test_in_out(self): #pylint: disable=no-self-use layer = Encoder(128) diff --git 
a/tests/test_tacotron2_model.py b/tests/test_tacotron2_model.py index 2faccd75..28d39de5 100644 --- a/tests/test_tacotron2_model.py +++ b/tests/test_tacotron2_model.py @@ -9,6 +9,7 @@ from torch import nn, optim from mozilla_voice_tts.tts.layers.losses import MSELossMasked from mozilla_voice_tts.tts.models.tacotron2 import Tacotron2 from mozilla_voice_tts.utils.io import load_config +from mozilla_voice_tts.utils.audio import AudioProcessor #pylint: disable=unused-variable @@ -18,6 +19,9 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") c = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) +ap = AudioProcessor(**c.audio) +WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") + class TacotronTrainTest(unittest.TestCase): def test_train_step(self): # pylint: disable=no-self-use @@ -70,3 +74,167 @@ class TacotronTrainTest(unittest.TestCase): ), "param {} with shape {} not updated!! \n{}\n{}".format( count, param.shape, param, param_ref) count += 1 + + +class MultiSpeakeTacotronTrainTest(unittest.TestCase): + @staticmethod + def test_train_step(): + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) + input_lengths = torch.randint(100, 128, (8, )).long().to(device) + input_lengths = torch.sort(input_lengths, descending=True)[0] + mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_lengths[0] = 30 + stop_targets = torch.zeros(8, 30, 1).float().to(device) + speaker_embeddings = torch.rand(8, 55).to(device) + + for idx in mel_lengths: + stop_targets[:, int(idx.item()):, 0] = 1.0 + + stop_targets = stop_targets.view(input_dummy.shape[0], + stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() + + criterion = MSELossMasked(seq_len_norm=False).to(device) + criterion_st = nn.BCEWithLogitsLoss().to(device) + model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, speaker_embedding_dim=55).to(device) + model.train() + model_ref = copy.deepcopy(model) + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count += 1 + optimizer = optim.Adam(model.parameters(), lr=c.lr) + for i in range(5): + mel_out, mel_postnet_out, align, stop_tokens = model.forward( + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_embeddings=speaker_embeddings) + assert torch.sigmoid(stop_tokens).data.max() <= 1.0 + assert torch.sigmoid(stop_tokens).data.min() >= 0.0 + optimizer.zero_grad() + loss = criterion(mel_out, mel_spec, mel_lengths) + stop_loss = criterion_st(stop_tokens, stop_targets) + loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss + loss.backward() + optimizer.step() + # check parameter changes + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + # ignore pre-higway layer since it works conditional + # if count not in [145, 59]: + assert (param != param_ref).any( + ), "param {} with shape {} not updated!! 
\n{}\n{}".format( + count, param.shape, param, param_ref) + count += 1 + +class TacotronGSTTrainTest(unittest.TestCase): + #pylint: disable=no-self-use + def test_train_step(self): + # with random gst mel style + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) + input_lengths = torch.randint(100, 128, (8, )).long().to(device) + input_lengths = torch.sort(input_lengths, descending=True)[0] + mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_lengths[0] = 30 + stop_targets = torch.zeros(8, 30, 1).float().to(device) + speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + + for idx in mel_lengths: + stop_targets[:, int(idx.item()):, 0] = 1.0 + + stop_targets = stop_targets.view(input_dummy.shape[0], + stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() + + criterion = MSELossMasked(seq_len_norm=False).to(device) + criterion_st = nn.BCEWithLogitsLoss().to(device) + model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, gst=True, gst_embedding_dim=c.gst['gst_embedding_dim'], gst_num_heads=c.gst['gst_num_heads'], gst_style_tokens=c.gst['gst_style_tokens']).to(device) + model.train() + model_ref = copy.deepcopy(model) + count = 0 + for param, param_ref in zip(model.parameters(), model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count += 1 + optimizer = optim.Adam(model.parameters(), lr=c.lr) + for i in range(10): + mel_out, mel_postnet_out, align, stop_tokens = model.forward( + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) + assert torch.sigmoid(stop_tokens).data.max() <= 1.0 + assert torch.sigmoid(stop_tokens).data.min() >= 0.0 + optimizer.zero_grad() + loss = criterion(mel_out, mel_spec, mel_lengths) + stop_loss = criterion_st(stop_tokens, stop_targets) + loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss + loss.backward() + optimizer.step() + # check parameter changes + count = 0 + for name_param, param_ref in zip(model.named_parameters(), model_ref.parameters()): + # ignore pre-higway layer since it works conditional + # if count not in [145, 59]: + name, param = name_param + if name == 'gst_layer.encoder.recurrence.weight_hh_l0': + #print(param.grad) + continue + assert (param != param_ref).any( + ), "param {} {} with shape {} not updated!! 
\n{}\n{}".format( + name, count, param.shape, param, param_ref) + count += 1 + + # with file gst style + mel_spec = torch.FloatTensor(ap.melspectrogram(ap.load_wav(WAV_FILE)))[:, :30].unsqueeze(0).transpose(1, 2).to(device) + mel_spec = mel_spec.repeat(8, 1, 1) + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) + input_lengths = torch.randint(100, 128, (8, )).long().to(device) + input_lengths = torch.sort(input_lengths, descending=True)[0] + mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_lengths[0] = 30 + stop_targets = torch.zeros(8, 30, 1).float().to(device) + speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + + for idx in mel_lengths: + stop_targets[:, int(idx.item()):, 0] = 1.0 + + stop_targets = stop_targets.view(input_dummy.shape[0], + stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() + + criterion = MSELossMasked(seq_len_norm=False).to(device) + criterion_st = nn.BCEWithLogitsLoss().to(device) + model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, gst=True, gst_embedding_dim=c.gst['gst_embedding_dim'], gst_num_heads=c.gst['gst_num_heads'], gst_style_tokens=c.gst['gst_style_tokens']).to(device) + model.train() + model_ref = copy.deepcopy(model) + count = 0 + for param, param_ref in zip(model.parameters(), model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count += 1 + optimizer = optim.Adam(model.parameters(), lr=c.lr) + for i in range(10): + mel_out, mel_postnet_out, align, stop_tokens = model.forward( + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) + assert torch.sigmoid(stop_tokens).data.max() <= 1.0 + assert torch.sigmoid(stop_tokens).data.min() >= 0.0 + optimizer.zero_grad() + loss = criterion(mel_out, mel_spec, mel_lengths) + stop_loss = criterion_st(stop_tokens, stop_targets) + loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss + loss.backward() + optimizer.step() + # check parameter changes + count = 0 + for name_param, param_ref in zip(model.named_parameters(), model_ref.parameters()): + # ignore pre-higway layer since it works conditional + # if count not in [145, 59]: + name, param = name_param + if name == 'gst_layer.encoder.recurrence.weight_hh_l0': + #print(param.grad) + continue + assert (param != param_ref).any( + ), "param {} {} with shape {} not updated!! 
\n{}\n{}".format( + name, count, param.shape, param, param_ref) + count += 1 diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py index d15a6705..0b80243f 100644 --- a/tests/test_tacotron_model.py +++ b/tests/test_tacotron_model.py @@ -9,6 +9,7 @@ from torch import nn, optim from mozilla_voice_tts.tts.layers.losses import L1LossMasked from mozilla_voice_tts.tts.models.tacotron import Tacotron from mozilla_voice_tts.utils.io import load_config +from mozilla_voice_tts.utils.audio import AudioProcessor #pylint: disable=unused-variable @@ -18,6 +19,9 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") c = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) +ap = AudioProcessor(**c.audio) +WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") + def count_parameters(model): r"""Count number of trainable parameters in a network""" @@ -31,7 +35,7 @@ class TacotronTrainTest(unittest.TestCase): input_lengths = torch.randint(100, 129, (8, )).long().to(device) input_lengths[-1] = 128 mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device) + linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device) mel_lengths = torch.randint(20, 30, (8, )).long().to(device) stop_targets = torch.zeros(8, 30, 1).float().to(device) speaker_ids = torch.randint(0, 5, (8, )).long().to(device) @@ -49,7 +53,7 @@ class TacotronTrainTest(unittest.TestCase): model = Tacotron( num_chars=32, num_speakers=5, - postnet_output_dim=c.audio['num_freq'], + postnet_output_dim=c.audio['fft_size'], decoder_output_dim=c.audio['num_mels'], r=c.r, memory_size=c.memory_size @@ -85,15 +89,78 @@ class TacotronTrainTest(unittest.TestCase): count, param.shape, param, param_ref) count += 1 - -class TacotronGSTTrainTest(unittest.TestCase): +class MultiSpeakeTacotronTrainTest(unittest.TestCase): @staticmethod def test_train_step(): + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) + input_lengths = torch.randint(100, 129, (8, )).long().to(device) + input_lengths[-1] = 128 + mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device) + mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + stop_targets = torch.zeros(8, 30, 1).float().to(device) + speaker_embeddings = torch.rand(8, 55).to(device) + + for idx in mel_lengths: + stop_targets[:, int(idx.item()):, 0] = 1.0 + + stop_targets = stop_targets.view(input_dummy.shape[0], + stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > + 0.0).unsqueeze(2).float().squeeze() + + criterion = L1LossMasked(seq_len_norm=False).to(device) + criterion_st = nn.BCEWithLogitsLoss().to(device) + model = Tacotron( + num_chars=32, + num_speakers=5, + postnet_output_dim=c.audio['fft_size'], + decoder_output_dim=c.audio['num_mels'], + r=c.r, + memory_size=c.memory_size, + speaker_embedding_dim=55, + ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor + model.train() + print(" > Num parameters for Tacotron model:%s" % + (count_parameters(model))) + model_ref = copy.deepcopy(model) + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count += 1 + optimizer = optim.Adam(model.parameters(), lr=c.lr) + for _ in range(5): + mel_out, linear_out, align, stop_tokens = model.forward( + input_dummy, input_lengths, mel_spec, mel_lengths, + speaker_embeddings=speaker_embeddings) + 
optimizer.zero_grad() + loss = criterion(mel_out, mel_spec, mel_lengths) + stop_loss = criterion_st(stop_tokens, stop_targets) + loss = loss + criterion(linear_out, linear_spec, + mel_lengths) + stop_loss + loss.backward() + optimizer.step() + # check parameter changes + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + # ignore pre-higway layer since it works conditional + # if count not in [145, 59]: + assert (param != param_ref).any( + ), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref) + count += 1 + +class TacotronGSTTrainTest(unittest.TestCase): + @staticmethod + def test_train_step(): + # with random gst mel style input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 129, (8, )).long().to(device) input_lengths[-1] = 128 mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device) - linear_spec = torch.rand(8, 120, c.audio['num_freq']).to(device) + linear_spec = torch.rand(8, 120, c.audio['fft_size']).to(device) mel_lengths = torch.randint(20, 120, (8, )).long().to(device) mel_lengths[-1] = 120 stop_targets = torch.zeros(8, 120, 1).float().to(device) @@ -113,13 +180,82 @@ class TacotronGSTTrainTest(unittest.TestCase): num_chars=32, num_speakers=5, gst=True, - postnet_output_dim=c.audio['num_freq'], + gst_embedding_dim=c.gst['gst_embedding_dim'], + gst_num_heads=c.gst['gst_num_heads'], + gst_style_tokens=c.gst['gst_style_tokens'], + postnet_output_dim=c.audio['fft_size'], decoder_output_dim=c.audio['num_mels'], r=c.r, memory_size=c.memory_size ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor model.train() - print(model) + # print(model) + print(" > Num parameters for Tacotron GST model:%s" % + (count_parameters(model))) + model_ref = copy.deepcopy(model) + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count += 1 + optimizer = optim.Adam(model.parameters(), lr=c.lr) + for _ in range(10): + mel_out, linear_out, align, stop_tokens = model.forward( + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) + optimizer.zero_grad() + loss = criterion(mel_out, mel_spec, mel_lengths) + stop_loss = criterion_st(stop_tokens, stop_targets) + loss = loss + criterion(linear_out, linear_spec, + mel_lengths) + stop_loss + loss.backward() + optimizer.step() + # check parameter changes + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + # ignore pre-higway layer since it works conditional + assert (param != param_ref).any( + ), "param {} with shape {} not updated!! 
\n{}\n{}".format( + count, param.shape, param, param_ref) + count += 1 + + # with file gst style + mel_spec = torch.FloatTensor(ap.melspectrogram(ap.load_wav(WAV_FILE)))[:, :120].unsqueeze(0).transpose(1, 2).to(device) + mel_spec = mel_spec.repeat(8, 1, 1) + + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) + input_lengths = torch.randint(100, 129, (8, )).long().to(device) + input_lengths[-1] = 128 + linear_spec = torch.rand(8, mel_spec.size(1), c.audio['fft_size']).to(device) + mel_lengths = torch.randint(20, mel_spec.size(1), (8, )).long().to(device) + mel_lengths[-1] = mel_spec.size(1) + stop_targets = torch.zeros(8, mel_spec.size(1), 1).float().to(device) + speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + + for idx in mel_lengths: + stop_targets[:, int(idx.item()):, 0] = 1.0 + + stop_targets = stop_targets.view(input_dummy.shape[0], + stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > + 0.0).unsqueeze(2).float().squeeze() + + criterion = L1LossMasked(seq_len_norm=False).to(device) + criterion_st = nn.BCEWithLogitsLoss().to(device) + model = Tacotron( + num_chars=32, + num_speakers=5, + gst=True, + gst_embedding_dim=c.gst['gst_embedding_dim'], + gst_num_heads=c.gst['gst_num_heads'], + gst_style_tokens=c.gst['gst_style_tokens'], + postnet_output_dim=c.audio['fft_size'], + decoder_output_dim=c.audio['num_mels'], + r=c.r, + memory_size=c.memory_size + ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor + model.train() + # print(model) print(" > Num parameters for Tacotron GST model:%s" % (count_parameters(model))) model_ref = copy.deepcopy(model) diff --git a/utils/generic_utils.py b/utils/generic_utils.py new file mode 100644 index 00000000..3bb99e08 --- /dev/null +++ b/utils/generic_utils.py @@ -0,0 +1,374 @@ +import os +import glob +import torch +import shutil +import datetime +import subprocess +import importlib +import numpy as np +from collections import Counter + + +def get_git_branch(): + try: + out = subprocess.check_output(["git", "branch"]).decode("utf8") + current = next(line for line in out.split("\n") + if line.startswith("*")) + current.replace("* ", "") + except subprocess.CalledProcessError: + current = "inside_docker" + return current + + +def get_commit_hash(): + """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" + # try: + # subprocess.check_output(['git', 'diff-index', '--quiet', + # 'HEAD']) # Verify client is clean + # except: + # raise RuntimeError( + # " !! 
Commit before training to get the commit hash.") + try: + commit = subprocess.check_output( + ['git', 'rev-parse', '--short', 'HEAD']).decode().strip() + # Not copying .git folder into docker container + except subprocess.CalledProcessError: + commit = "0000000" + print(' > Git Hash: {}'.format(commit)) + return commit + + +def create_experiment_folder(root_path, model_name, debug): + """ Create a folder with the current date and time """ + date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") + if debug: + commit_hash = 'debug' + else: + commit_hash = get_commit_hash() + output_folder = os.path.join( + root_path, model_name + '-' + date_str + '-' + commit_hash) + os.makedirs(output_folder, exist_ok=True) + print(" > Experiment folder: {}".format(output_folder)) + return output_folder + + +def remove_experiment_folder(experiment_path): + """Check folder if there is a checkpoint, otherwise remove the folder""" + + checkpoint_files = glob.glob(experiment_path + "/*.pth.tar") + if not checkpoint_files: + if os.path.exists(experiment_path): + shutil.rmtree(experiment_path, ignore_errors=True) + print(" ! Run is removed from {}".format(experiment_path)) + else: + print(" ! Run is kept in {}".format(experiment_path)) + + +def count_parameters(model): + r"""Count number of trainable parameters in a network""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def split_dataset(items): + is_multi_speaker = False + speakers = [item[-1] for item in items] + is_multi_speaker = len(set(speakers)) > 1 + eval_split_size = 500 if len(items) * 0.01 > 500 else int( + len(items) * 0.01) + assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples." + np.random.seed(0) + np.random.shuffle(items) + if is_multi_speaker: + items_eval = [] + # most stupid code ever -- Fix it ! + while len(items_eval) < eval_split_size: + speakers = [item[-1] for item in items] + speaker_counter = Counter(speakers) + item_idx = np.random.randint(0, len(items)) + if speaker_counter[items[item_idx][-1]] > 1: + items_eval.append(items[item_idx]) + del items[item_idx] + return items_eval, items + return items[:eval_split_size], items[eval_split_size:] + + +# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 +def sequence_mask(sequence_length, max_len=None): + if max_len is None: + max_len = sequence_length.data.max() + batch_size = sequence_length.size(0) + seq_range = torch.arange(0, max_len).long() + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + if sequence_length.is_cuda: + seq_range_expand = seq_range_expand.to(sequence_length.device) + seq_length_expand = ( + sequence_length.unsqueeze(1).expand_as(seq_range_expand)) + # B x T_max + return seq_range_expand < seq_length_expand + + +def set_init_dict(model_dict, checkpoint_state, c): + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + for k, v in checkpoint_state.items(): + if k not in model_dict: + print(" | > Layer missing in the model definition: {}".format(k)) + # 1. filter out unnecessary keys + pretrained_dict = { + k: v + for k, v in checkpoint_state.items() if k in model_dict + } + # 2. filter out different size layers + pretrained_dict = { + k: v + for k, v in pretrained_dict.items() + if v.numel() == model_dict[k].numel() + } + # 3. 
skip reinit layers + if c.reinit_layers is not None: + for reinit_layer_name in c.reinit_layers: + pretrained_dict = { + k: v + for k, v in pretrained_dict.items() + if reinit_layer_name not in k + } + # 4. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + print(" | > {} / {} layers are restored.".format(len(pretrained_dict), + len(model_dict))) + return model_dict + + +def setup_model(num_chars, num_speakers, c): + print(" > Using model: {}".format(c.model)) + MyModel = importlib.import_module('TTS.models.' + c.model.lower()) + MyModel = getattr(MyModel, c.model) + if c.model.lower() in "tacotron": + model = MyModel(num_chars=num_chars, + num_speakers=num_speakers, + r=c.r, + postnet_output_dim=int(c.audio['fft_size'] / 2 + 1), + decoder_output_dim=c.audio['num_mels'], + gst=c.use_gst, + gst_embedding_dim=c.gst['gst_embedding_dim'], + gst_num_heads=c.gst['gst_num_heads'], + gst_style_tokens=c.gst['gst_style_tokens'], + memory_size=c.memory_size, + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder, + double_decoder_consistency=c.double_decoder_consistency, + ddc_r=c.ddc_r) + elif c.model.lower() == "tacotron2": + model = MyModel(num_chars=num_chars, + num_speakers=num_speakers, + r=c.r, + postnet_output_dim=c.audio['num_mels'], + decoder_output_dim=c.audio['num_mels'], + gst=c.use_gst, + gst_embedding_dim=c.gst['gst_embedding_dim'], + gst_num_heads=c.gst['gst_num_heads'], + gst_style_tokens=c.gst['gst_style_tokens'], + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder, + double_decoder_consistency=c.double_decoder_consistency, + ddc_r=c.ddc_r) + return model + +class KeepAverage(): + def __init__(self): + self.avg_values = {} + self.iters = {} + + def __getitem__(self, key): + return self.avg_values[key] + + def items(self): + return self.avg_values.items() + + def add_value(self, name, init_val=0, init_iter=0): + self.avg_values[name] = init_val + self.iters[name] = init_iter + + def update_value(self, name, value, weighted_avg=False): + if name not in self.avg_values: + # add value if not exist before + self.add_value(name, init_val=value) + else: + # else update existing value + if weighted_avg: + self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value + self.iters[name] += 1 + else: + self.avg_values[name] = self.avg_values[name] * \ + self.iters[name] + value + self.iters[name] += 1 + self.avg_values[name] /= self.iters[name] + + def add_values(self, name_dict): + for key, value in name_dict.items(): + self.add_value(key, init_val=value) + + def update_values(self, value_dict): + for key, value in value_dict.items(): + self.update_value(key, value) + + +def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None): + if alternative in c.keys() and c[alternative] is not None: + return 
+ if restricted: + assert name in c.keys(), f' [!] {name} not defined in config.json' + if name in c.keys(): + if max_val: + assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}' + if min_val: + assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}' + if enum_list: + assert c[name].lower() in enum_list, f' [!] {name} is not a valid value' + if val_type: + assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' + + +def check_config(c): + _check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) + _check_argument('run_name', c, restricted=True, val_type=str) + _check_argument('run_description', c, val_type=str) + + # AUDIO + _check_argument('audio', c, restricted=True, val_type=dict) + + # audio processing parameters + _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) + _check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) + _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) + _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') + _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') + _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) + _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) + _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) + _check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) + _check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) + + # vocabulary parameters + _check_argument('characters', c, restricted=False, val_type=dict) + _check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + + # normalization parameters + _check_argument('signal_norm', c['audio'], restricted=True, val_type=bool) + _check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool) + _check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000) + _check_argument('clip_norm', c['audio'], restricted=True, val_type=bool) + _check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000) + _check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0) + _check_argument('spec_gain', c['audio'], restricted=True, val_type=float, min_val=1, max_val=100) + _check_argument('do_trim_silence', c['audio'], 
restricted=True, val_type=bool) + _check_argument('trim_db', c['audio'], restricted=True, val_type=int) + + # training parameters + _check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) + _check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) + _check_argument('r', c, restricted=True, val_type=int, min_val=1) + _check_argument('gradual_training', c, restricted=False, val_type=list) + _check_argument('loss_masking', c, restricted=True, val_type=bool) + # _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) + + # validation parameters + _check_argument('run_eval', c, restricted=True, val_type=bool) + _check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0) + _check_argument('test_sentences_file', c, restricted=False, val_type=str) + + # optimizer + _check_argument('noam_schedule', c, restricted=False, val_type=bool) + _check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0) + _check_argument('epochs', c, restricted=True, val_type=int, min_val=1) + _check_argument('lr', c, restricted=True, val_type=float, min_val=0) + _check_argument('wd', c, restricted=True, val_type=float, min_val=0) + _check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) + _check_argument('seq_len_norm', c, restricted=True, val_type=bool) + + # tacotron prenet + _check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1) + _check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn']) + _check_argument('prenet_dropout', c, restricted=True, val_type=bool) + + # attention + _check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original']) + _check_argument('attention_heads', c, restricted=True, val_type=int) + _check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax']) + _check_argument('windowing', c, restricted=True, val_type=bool) + _check_argument('use_forward_attn', c, restricted=True, val_type=bool) + _check_argument('forward_attn_mask', c, restricted=True, val_type=bool) + _check_argument('transition_agent', c, restricted=True, val_type=bool) + _check_argument('transition_agent', c, restricted=True, val_type=bool) + _check_argument('location_attn', c, restricted=True, val_type=bool) + _check_argument('bidirectional_decoder', c, restricted=True, val_type=bool) + _check_argument('double_decoder_consistency', c, restricted=True, val_type=bool) + _check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int) + + # stopnet + _check_argument('stopnet', c, restricted=True, val_type=bool) + _check_argument('separate_stopnet', c, restricted=True, val_type=bool) + + # tensorboard + _check_argument('print_step', c, restricted=True, val_type=int, min_val=1) + _check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1) + _check_argument('save_step', c, restricted=True, val_type=int, min_val=1) + _check_argument('checkpoint', c, restricted=True, val_type=bool) + _check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) + + # dataloading + # pylint: disable=import-outside-toplevel + from TTS.utils.text import cleaners + _check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners)) + _check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) + _check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0) + 
_check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0) + _check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0) + _check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0) + _check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10) + + # paths + _check_argument('output_path', c, restricted=True, val_type=str) + + # multi-speaker + _check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) + + # GST + _check_argument('use_gst', c, restricted=True, val_type=bool) + _check_argument('gst', c, restricted=True, val_type=dict) + _check_argument('gst_style_input', c['gst'], restricted=True, val_type=str) + _check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1) + _check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1) + _check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1) + + # datasets - checking only the first entry + _check_argument('datasets', c, restricted=True, val_type=list) + for dataset_entry in c['datasets']: + _check_argument('name', dataset_entry, restricted=True, val_type=str) + _check_argument('path', dataset_entry, restricted=True, val_type=str) + _check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str) + _check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
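
Taken together, the new helpers in utils/generic_utils.py cover the start-up path of a training script: check_config() validates the parsed JSONC config, setup_model() builds a Tacotron or Tacotron2 instance from it, and count_parameters() / KeepAverage handle the bookkeeping that the trainers report. The sketch below is an editorial illustration of how they are expected to fit together, not code from the patch; it assumes the repository root is on PYTHONPATH, that the TTS.* import paths used inside generic_utils.py resolve, and that "config.json", num_chars=128 and num_speakers=0 are placeholders for a real config, symbol count and speaker count.

# Editorial sketch (not part of the patch): wiring the new generic_utils helpers together at start-up.
from TTS.utils.io import load_config            # same loader used by the notebooks in this patch
from utils.generic_utils import (check_config, setup_model,
                                 count_parameters, KeepAverage)

c = load_config("config.json")                   # placeholder path to a JSONC training config
check_config(c)                                  # asserts the type/range of every required field up front

num_chars = 128                                  # placeholder; normally len(phonemes) or len(symbols)
num_speakers = 0                                 # single-speaker run for this sketch
model = setup_model(num_chars, num_speakers, c)  # builds Tacotron or Tacotron2 according to c.model
print(" > Model has {} trainable parameters".format(count_parameters(model)))

keep_avg = KeepAverage()                         # running averages for console/TensorBoard stats
keep_avg.update_values({"loss": 1.0, "grad_norm": 0.5})
keep_avg.update_values({"loss": 0.8, "grad_norm": 0.4})
print(keep_avg["loss"])                          # current running average of the reported loss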