diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py
index 81e7a9f2..45e34a5e 100644
--- a/TTS/bin/train_glow_tts.py
+++ b/TTS/bin/train_glow_tts.py
@@ -59,6 +59,12 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             enable_eos_bos=c.enable_eos_bos_chars,
             verbose=verbose,
             speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
+
+        if c.use_phonemes and c.compute_input_seq_cache:
+            # precompute phonemes to have a better estimate of sequence lengths.
+            dataset.compute_input_seq(c.num_loader_workers)
+        dataset.sort_items()
+
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
@@ -112,7 +118,7 @@ def format_data(data):
            avg_text_length, avg_spec_length, attn_mask, item_idx


-def data_depended_init(model, ap):
+def data_depended_init(data_loader, model, ap):
     """Data depended initialization for activation normalization."""
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
@@ -123,7 +129,6 @@ def data_depended_init(model, ap):
         if getattr(f, "set_ddi", False):
             f.set_ddi(True)

-    data_loader = setup_loader(ap, 1, is_val=False)
     model.train()
     print(" > Data depended initialization ... ")
     num_iter = 0
@@ -152,10 +157,9 @@ def data_depended_init(model, ap):
     return model


-def train(model, criterion, optimizer, scheduler,
+def train(data_loader, model, criterion, optimizer, scheduler,
           ap, global_step, epoch):
-    data_loader = setup_loader(ap, 1, is_val=False,
-                               verbose=(epoch == 0))
+
     model.train()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -308,8 +312,7 @@ def train(model, criterion, optimizer, scheduler,


 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch):
-    data_loader = setup_loader(ap, 1, is_val=True)
+def evaluate(data_loader, model, criterion, ap, global_step, epoch):
     model.eval()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -533,14 +536,18 @@ def main(args):  # pylint: disable=redefined-outer-name
     if 'best_loss' not in locals():
         best_loss = float('inf')

+    # define dataloaders
+    train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
+    eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
+
     global_step = args.restore_step
-    model = data_depended_init(model, ap)
+    model = data_depended_init(train_loader, model, ap)
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
-        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
+        train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
                                                  scheduler, ap, global_step,
                                                  epoch)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
+        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_loss']
         if c.run_eval:
diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py
index b3fbc415..03032bfc 100644
--- a/TTS/bin/train_tacotron.py
+++ b/TTS/bin/train_tacotron.py
@@ -61,6 +61,12 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             enable_eos_bos=c.enable_eos_bos_chars,
             verbose=verbose,
             speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
+
+        if c.use_phonemes and c.compute_input_seq_cache:
+            # precompute phonemes to have a better estimate of sequence lengths.
+            dataset.compute_input_seq(c.num_loader_workers)
+        dataset.sort_items()
+
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
@@ -123,10 +129,8 @@ def format_data(data):
     return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length


-def train(model, criterion, optimizer, optimizer_st, scheduler,
+def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
           ap, global_step, epoch, scaler, scaler_st):
-    data_loader = setup_loader(ap, model.decoder.r, is_val=False,
-                               verbose=(epoch == 0), speaker_mapping=speaker_mapping)
     model.train()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -324,8 +328,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,


 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch):
-    data_loader = setup_loader(ap, model.decoder.r, is_val=True, speaker_mapping=speaker_mapping)
+def evaluate(data_loader, model, criterion, ap, global_step, epoch):
     model.eval()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -583,6 +586,13 @@ def main(args):  # pylint: disable=redefined-outer-name
     if 'best_loss' not in locals():
         best_loss = float('inf')

+    # define data loaders
+    train_loader = setup_loader(ap,
+                                model.decoder.r,
+                                is_val=False,
+                                verbose=True)
+    eval_loader = setup_loader(ap, model.decoder.r, is_val=True)
+
     global_step = args.restore_step
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
@@ -594,16 +604,27 @@ def main(args):  # pylint: disable=redefined-outer-name
         if c.bidirectional_decoder:
             model.decoder_backward.set_r(r)
         print("\n > Number of output frames:", model.decoder.r)
-        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
+        train_avg_loss_dict, global_step = train(train_loader, model,
+                                                 criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
-                                                 global_step, epoch, scaler, scaler_st)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
+                                                 global_step, epoch, scaler,
+                                                 scaler_st)
+        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
+                                      global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_postnet_loss']
-        best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH, scaler=scaler.state_dict() if c.mixed_precision else None)
+        best_loss = save_best_model(
+            target_loss,
+            best_loss,
+            model,
+            optimizer,
+            global_step,
+            epoch,
+            c.r,
+            OUT_PATH,
+            scaler=scaler.state_dict() if c.mixed_precision else None)


 if __name__ == '__main__':
diff --git a/TTS/tts/configs/config.json b/TTS/tts/configs/config.json
index 7274fd9d..766aca40 100644
--- a/TTS/tts/configs/config.json
+++ b/TTS/tts/configs/config.json
@@ -131,6 +131,7 @@
     "batch_group_size": 4, //Number of batches to shuffle after bucketing.
     "min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 153, // DATASET-RELATED: maximum text length
+    "compute_input_seq_cache": false, // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage.

     // PATHS
     "output_path": "/home/erogol/Models/LJSpeech/",
diff --git a/TTS/tts/configs/glow_tts_gated_conv.json b/TTS/tts/configs/glow_tts_gated_conv.json
index dbcdbbde..d34fbaf0 100644
--- a/TTS/tts/configs/glow_tts_gated_conv.json
+++ b/TTS/tts/configs/glow_tts_gated_conv.json
@@ -105,6 +105,7 @@
     "min_seq_len": 3, // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 500, // DATASET-RELATED: maximum text length
     "compute_f0": false, // compute f0 values in data-loader
+    "compute_input_seq_cache": false, // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage.

     // PATHS
     "output_path": "/home/erogol/Models/LJSpeech/",
diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py
index 9b8cbdd3..5524c379 100644
--- a/TTS/tts/datasets/TTSDataset.py
+++ b/TTS/tts/datasets/TTSDataset.py
@@ -63,6 +63,7 @@ class MyDataset(Dataset):
         self.enable_eos_bos = enable_eos_bos
         self.speaker_mapping = speaker_mapping
         self.verbose = verbose
+        self.input_seq_computed = False
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
             os.makedirs(phoneme_cache_path, exist_ok=True)
         if self.verbose:
@@ -71,7 +72,6 @@ class MyDataset(Dataset):
             if use_phonemes:
                 print(" | > phoneme language: {}".format(phoneme_language))
             print(" | > Number of instances : {}".format(len(self.items)))
-        self.sort_items()

     def load_wav(self, filename):
         audio = self.ap.load_wav(filename)
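Note on the new call sites: the patch invokes dataset.compute_input_seq(c.num_loader_workers) before dataset.sort_items(), but the body of compute_input_seq is not part of the hunks above. Below is a minimal sketch of the intended mechanics, not the repository's implementation: it assumes each entry of MyDataset.items is a list whose first element is the raw text, and it stands in the real text-to-phoneme conversion with a placeholder. The names DatasetSketch and _compute_one are illustrative only.

    # Sketch of the precompute step that compute_input_seq is expected to
    # perform, under the assumptions stated above.
    from multiprocessing import Pool


    def _compute_one(item):
        # Placeholder for the real text -> input-sequence conversion
        # (character IDs, or phoneme IDs read from / written to the
        # phoneme cache when use_phonemes is enabled).
        text = item[0]
        return list(text)


    class DatasetSketch:
        def __init__(self, items):
            # items: list of [text, wav_file, ...] rows, as in MyDataset
            self.items = items
            self.input_seq_computed = False

        def compute_input_seq(self, num_workers=0):
            # Compute every input sequence up front so sort_items() can
            # bucket by true sequence length instead of raw text length.
            if num_workers > 0:
                with Pool(num_workers) as pool:
                    seqs = pool.map(_compute_one, self.items)
            else:
                seqs = [_compute_one(item) for item in self.items]
            for item, seq in zip(self.items, seqs):
                item.append(seq)
            self.input_seq_computed = True

        def sort_items(self):
            # Sort by precomputed sequence length when available,
            # otherwise fall back to the raw text length.
            if self.input_seq_computed:
                self.items.sort(key=lambda item: len(item[-1]))
            else:
                self.items.sort(key=lambda item: len(item[0]))


    # Example: mirrors the call order added in setup_loader() above.
    dataset = DatasetSketch([["hello world", "wav1.wav"], ["hi", "wav2.wav"]])
    dataset.compute_input_seq(num_workers=0)
    dataset.sort_items()

This is also why sort_items() moves out of the MyDataset constructor: it used to run at construction time, when only raw character counts were available. Deferring it until after the optional precompute lets the min_seq_len/max_seq_len bucketing operate on actual input-sequence (e.g. phoneme) lengths, as the in-diff comment "precompute phonemes to have a better estimate of sequence lengths" indicates.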