diff --git a/.compute b/.compute
index 63dea7a7..2dbc7bb2 100644
--- a/.compute
+++ b/.compute
@@ -4,13 +4,13 @@ yes | apt-get install ffmpeg
 yes | apt-get install espeak
 yes | apt-get install tmux
 yes | apt-get install zsh
-pip3 install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl
+# pip3 install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl
 # wget https://www.dropbox.com/s/m8waow6b3ydpf6h/MozillaDataset.tar.gz?dl=0 -O /data/rw/home/mozilla.tar
 wget https://www.dropbox.com/s/wqn5v3wkktw9lmo/install.sh?dl=0 -O install.sh
 sudo sh install.sh
 python3 setup.py develop
 # cp -R ${USER_DIR}/GermanData ../tmp/
-python3 distribute.py --config_path config_libritts.json --data_path /data/rw/home/LibriTTS/train-clean-360/
+python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
 # cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
 # python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/
-while true; do sleep 1000000; done
+# while true; do sleep 1000000; done
diff --git a/config.json b/config.json
index 24d26e16..ee83f660 100644
--- a/config.json
+++ b/config.json
@@ -1,6 +1,6 @@
 {
-    "run_name": "mozilla-no-loc-fattn-stopnet-sigmoid-loss_masking",
-    "run_description": "using forward attention, with original prenet, loss masking,separate stopnet, sigmoid. Compare this with 4817. Pytorch DPP",
+    "run_name": "ljspeech",
+    "run_description": "gradual training with prenet frame size 1 + no maxout for cbhg + symmetric norm.",
 
     "audio":{
         // Audio processing parameters
@@ -16,8 +16,8 @@
         "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
         // Normalization parameters
         "signal_norm": true,     // normalize the spec values in range [0, 1]
-        "symmetric_norm": false, // move normalization to range [-1, 1]
-        "max_norm": 1,           // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+        "symmetric_norm": true,  // move normalization to range [-1, 1]
+        "max_norm": 4,           // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
         "clip_norm": true,       // clip normalized values into the range.
         "mel_fmin": 0.0,         // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
         "mel_fmax": 8000.0,      // maximum freq level for mel-spec. Tune for dataset!!
@@ -31,52 +31,53 @@
 
     "reinit_layers": [],
 
-    "model": "Tacotron2",          // one of the model in models/
+    "model": "Tacotron",           // one of the model in models/
     "grad_clip": 1,                // upper limit for gradients for clipping.
     "epochs": 1000,                // total number of epochs to train.
     "lr": 0.0001,                  // Initial learning rate. If Noam decay is active, maximum learning rate.
     "lr_decay": false,             // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000,          // Noam decay steps to increase the learning rate from 0 to "lr"
-    "windowing": false,            // Enables attention windowing. Used only in eval mode.
-    "memory_size": 5,              // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "memory_size": -1,             // ONLY TACOTRON - size of the memory queue used for storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame.
     "attention_norm": "sigmoid",   // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
-    "prenet_type": "original",     // "original" or "bn".
- "prenet_dropout": true, // enable/disable dropout at prenet. - "use_forward_attn": true, // enable/disable forward attention. In general, it aligns faster. - "forward_attn_mask": false, // Apply forward attention mask at inference to prevent bad modes. Try it if your model does not align well. - "transition_agent": false, // enable/disable transition agent of forward attention. - "location_attn": false, // enable_disable location sensitive attention. + "prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn". + "prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet. + "windowing": false, // Enables attention windowing. Used only in eval mode. + "use_forward_attn": false, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster. + "forward_attn_mask": false, + "transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention. + "location_attn": true, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default. "loss_masking": true, // enable / disable loss masking against the sequence padding. "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. "stopnet": true, // Train stopnet predicting the end of synthesis. "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER. "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. - "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. + "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. "eval_batch_size":16, - "r": 1, // Number of frames to predict for step. + "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. + "gradual_training": [[0, 7, 32], [10000, 5, 32], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. "wd": 0.000001, // Weight decay weight. "checkpoint": true, // If true, it saves checkpoints per "save_step" - "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. - "print_step": 10, // Number of steps to log traning on console. + "save_step": 10000, // Number of training steps expected to save traning stats and checkpoints. + "print_step": 25, // Number of steps to log traning on console. "batch_group_size": 0, //Number of batches to shuffle after bucketing. "run_eval": true, "test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time. "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences. - "data_path": "/media/erogol/data_ssd/Data/Mozilla/", // DATASET-RELATED: can overwritten from command argument - "meta_file_train": "metadata_train.txt", // DATASET-RELATED: metafile for training dataloader. - "meta_file_val": "metadata_val.txt", // DATASET-RELATED: metafile for evaluation dataloader. - "dataset": "mozilla", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. 
Use "tts_cache" for pre-computed dataset by extract_features.py - "min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training + "data_path": "/home/erogol/Data/LJSpeech-1.1/", // DATASET-RELATED: can overwritten from command argument + "meta_file_train": "metadata_train.csv", // DATASET-RELATED: metafile for training dataloader. + "meta_file_val": "metadata_val.csv", // DATASET-RELATED: metafile for evaluation dataloader. + "dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py + "min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training "max_seq_len": 150, // DATASET-RELATED: maximum text length - "output_path": "../keep/", // DATASET-RELATED: output path for all training outputs. + "output_path": "/media/erogol/data_ssd/Models/libri_tts/", // DATASET-RELATED: output path for all training outputs. "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. "num_val_loader_workers": 4, // number of evaluation data loader processes. "phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder. "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages "text_cleaner": "phoneme_cleaners", - "use_speaker_embedding": false // whether to use additional embeddings for separate speakers + "use_speaker_embedding": false } diff --git a/config_libritts.json b/config_libritts.json index 5579e565..658b9835 100644 --- a/config_libritts.json +++ b/config_libritts.json @@ -1,6 +1,6 @@ { "run_name": "libritts-360", - "run_description": "LibriTTS 360 clean with multi speaker embedding.", + "run_description": "LibriTTS 360 gradual traning with memory queue.", "audio":{ // Audio processing parameters @@ -31,13 +31,13 @@ "reinit_layers": [], - "model": "Tacotron2", // one of the model in models/ + "model": "Tacotron", // one of the model in models/ "grad_clip": 1, // upper limit for gradients for clipping. "epochs": 1000, // total number of epochs to train. "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. "lr_decay": false, // if true, Noam learning rate decaying is applied through training. "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" - "memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5. + "memory_size": 7, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5. "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron. "prenet_type": "original", // "original" or "bn". "prenet_dropout": true, // enable/disable dropout at prenet. @@ -52,9 +52,9 @@ "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER. "tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. - "batch_size": 24, // Batch size for training. Lower values than 32 might cause hard to learn attention. 
+ "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. "eval_batch_size":16, - "r": 1, // Number of frames to predict for step. + "r": 7, // Number of frames to predict for step. "wd": 0.000001, // Weight decay weight. "checkpoint": true, // If true, it saves checkpoints per "save_step" "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. diff --git a/layers/common_layers.py b/layers/common_layers.py index 0a7216ef..98fc70ae 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -234,7 +234,7 @@ class Attention(nn.Module): query, processed_inputs) # apply masking if mask is not None: - attention.data.masked_fill_(1 - mask, self._mask_value) + attention.data.masked_fill_(torch.bitwise_not(mask), self._mask_value) # apply windowing - only in eval mode if not self.training and self.windowing: attention = self.apply_windowing(attention, inputs) diff --git a/layers/tacotron.py b/layers/tacotron.py index 40225fa5..d8d0e57a 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -135,9 +135,6 @@ class CBHG(nn.Module): ]) # max pooling of conv bank, with padding # TODO: try average pooling OR larger kernel size - self.max_pool1d = nn.Sequential( - nn.ConstantPad1d([0, 1], value=0), - nn.MaxPool1d(kernel_size=2, stride=1, padding=0)) out_features = [K * conv_bank_features] + conv_projections[:-1] activations = [self.relu] * (len(conv_projections) - 1) activations += [None] @@ -186,7 +183,6 @@ class CBHG(nn.Module): outs.append(out) x = torch.cat(outs, dim=1) assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks) - x = self.max_pool1d(x) for conv1d in self.conv1d_projections: x = conv1d(x) # (B, T_in, hid_feature) @@ -270,23 +266,27 @@ class Decoder(nn.Module): memory_size (int): size of the past window. if <= 0 memory_size = r TODO: arguments """ + # Pylint gets confused by PyTorch conventions here #pylint: disable=attribute-defined-outside-init def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing, attn_norm, prenet_type, prenet_dropout, forward_attn, - trans_agent, forward_attn_mask, location_attn, separate_stopnet): + trans_agent, forward_attn_mask, location_attn, + separate_stopnet): super(Decoder, self).__init__() + self.r_init = r self.r = r self.in_features = in_features self.max_decoder_steps = 500 + self.use_memory_queue = memory_size > 0 self.memory_size = memory_size if memory_size > 0 else r self.memory_dim = memory_dim self.separate_stopnet = separate_stopnet self.query_dim = 256 # memory -> |Prenet| -> processed_memory self.prenet = Prenet( - memory_dim * self.memory_size, + memory_dim * self.memory_size if self.use_memory_queue else memory_dim, prenet_type, prenet_dropout, out_features=[256, 128]) @@ -311,21 +311,12 @@ class Decoder(nn.Module): self.decoder_rnns = nn.ModuleList( [nn.GRUCell(256, 256) for _ in range(2)]) # RNN_state -> |Linear| -> mel_spec - self.proj_to_mel = nn.Linear(256, memory_dim * r) + self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init) # learn init values instead of zero init. 
-        self.attention_rnn_init = nn.Embedding(1, 256)
-        self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
-        self.decoder_rnn_inits = nn.Embedding(2, 256)
-        self.stopnet = StopNet(256 + memory_dim * r)
-        # self.init_layers()
+        self.stopnet = StopNet(256 + memory_dim * self.r_init)
 
-    def init_layers(self):
-        torch.nn.init.xavier_uniform_(
-            self.project_to_decoder_in.weight,
-            gain=torch.nn.init.calculate_gain('linear'))
-        torch.nn.init.xavier_uniform_(
-            self.proj_to_mel.weight,
-            gain=torch.nn.init.calculate_gain('linear'))
+    def _set_r(self, new_r):
+        self.r = new_r
 
     def _reshape_memory(self, memory):
         """
@@ -347,13 +338,14 @@ class Decoder(nn.Module):
         B = inputs.size(0)
         T = inputs.size(1)
         # go frame as zeros matrix
-        self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())
-
+        if self.use_memory_queue:
+            self.memory_input = torch.zeros(B, self.memory_dim * self.memory_size, device=inputs.device)
+        else:
+            self.memory_input = torch.zeros(B, self.memory_dim, device=inputs.device)
         # decoder states
-        self.query = self.attention_rnn_init(
-            inputs.data.new_zeros(B).long())
+        self.attention_rnn_hidden = torch.zeros(B, 256, device=inputs.device)
         self.decoder_rnn_hiddens = [
-            self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
+            torch.zeros(B, 256, device=inputs.device)
             for idx in range(len(self.decoder_rnns))
         ]
         self.context_vec = inputs.data.new(B, self.in_features).zero_()
@@ -370,12 +362,13 @@ class Decoder(nn.Module):
     def decode(self, inputs, mask=None):
         # Prenet
         processed_memory = self.prenet(self.memory_input)
-
-        # Attention
-        self.query = self.attention_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
-        self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)
-
-        # Concat query and attention context vector
+        # Attention RNN
+        self.attention_rnn_hidden = self.attention_rnn(
+            torch.cat((processed_memory, self.current_context_vec), -1),
+            self.attention_rnn_hidden)
+        self.context_vec = self.attention_layer(
+            self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
+        # Concat RNN output and attention context vector
         decoder_input = self.project_to_decoder_in(
             torch.cat((self.query, self.context_vec), -1))
 
@@ -389,25 +382,30 @@ class Decoder(nn.Module):
         # predict mel vectors from decoder vectors
         output = self.proj_to_mel(decoder_output)
-        output = torch.sigmoid(output)
-
+        # output = torch.sigmoid(output)
         # predict stop token
         stopnet_input = torch.cat([decoder_output, output], -1)
         if self.separate_stopnet:
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        return output, stop_token, self.attention.attention_weights
+        output = output[:, : self.r * self.memory_dim]
+        return output, stop_token, self.attention_layer.attention_weights
 
-    def _update_memory_queue(self, new_memory):
-        if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
-            self.memory_input = torch.cat([
-                self.memory_input[:, self.r * self.memory_dim:].clone(),
-                new_memory
-            ],
-                                          dim=-1)
+    def _update_memory_input(self, new_memory):
+        if self.use_memory_queue:
+            if self.memory_size > self.r:
+                # memory queue size is larger than number of frames per decoder iter
+                self.memory_input = torch.cat([
+                    new_memory, self.memory_input[:, :(self.memory_size - self.r) * self.memory_dim].clone()
+                ],
+                                              dim=-1)
+            else:
+                # memory queue size smaller than number of frames per decoder iter
+                self.memory_input = new_memory[:, :self.memory_size * self.memory_dim]
         else:
-            self.memory_input = new_memory
+            # use only the last frame prediction
+            self.memory_input = new_memory[:, :self.memory_dim]
 
     def forward(self, inputs, memory, mask):
         """
@@ -433,7 +431,7 @@ class Decoder(nn.Module):
         while len(outputs) < memory.size(0):
             if t > 0:
                 new_memory = memory[t - 1]
-                self._update_memory_queue(new_memory)
+                self._update_memory_input(new_memory)
             output, stop_token, attention = self.decode(inputs, mask)
             outputs += [output]
             attentions += [attention]
@@ -460,7 +458,7 @@ class Decoder(nn.Module):
         while True:
             if t > 0:
                 new_memory = outputs[-1]
-                self._update_memory_queue(new_memory)
+                self._update_memory_input(new_memory)
             output, stop_token, attention = self.decode(inputs, None)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [output]
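For the `_update_memory_input` rewrite above: when the memory queue is enabled, the newest `r` decoder frames are pushed to the front of the queue and the oldest frames fall off the end; when `memory_size <= 0` (as in the new `config.json`), only the last predicted frame is kept. A toy shape check of the queue branch, with hypothetical sizes rather than the real `Decoder` attributes:

```python
import torch

B, memory_dim, r, memory_size = 2, 80, 7, 14              # hypothetical sizes
memory_input = torch.zeros(B, memory_dim * memory_size)   # queue of past frames
new_memory = torch.rand(B, memory_dim * r)                 # r freshly predicted frames

if memory_size > r:
    # newest frames go to the front; the oldest r frames' worth of values drop off the back
    memory_input = torch.cat(
        [new_memory, memory_input[:, :(memory_size - r) * memory_dim]], dim=-1)
else:
    # queue no longer than one decoder step: keep only the first memory_size frames
    memory_input = new_memory[:, :memory_size * memory_dim]

assert memory_input.shape == (B, memory_dim * memory_size)
```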
diff --git a/models/tacotron.py b/models/tacotron.py
index b7f40683..bf312db4 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -36,10 +36,8 @@ class Tacotron(nn.Module):
                                forward_attn, trans_agent, forward_attn_mask,
                                location_attn, separate_stopnet)
         self.postnet = PostCBHG(mel_dim)
-        self.last_linear = nn.Sequential(
-            nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
-            nn.Sigmoid())
-
+        self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim)
+
     def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
         B = characters.size(0)
         mask = sequence_mask(text_lengths).to(characters.device)
diff --git a/train.py b/train.py
index c893cb36..e28d7aa2 100644
--- a/train.py
+++ b/train.py
@@ -20,7 +20,7 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  load_config, remove_experiment_folder,
                                  save_best_model, save_checkpoint, weight_decay,
                                  set_init_dict, copy_config_file, setup_model,
-                                 split_dataset)
+                                 split_dataset, gradual_training_scheduler)
 from utils.logger import Logger
 from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
@@ -82,7 +82,7 @@ def setup_loader(ap, is_val=False, verbose=False):
 
 
 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
-          ap, epoch):
+          ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     if c.use_speaker_embedding:
         speaker_mapping = load_speaker_mapping(OUT_PATH)
@@ -92,8 +92,10 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     avg_decoder_loss = 0
     avg_stop_loss = 0
     avg_step_time = 0
+    avg_loader_time = 0
     print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
     batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
+    end_time = time.time()
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
 
@@ -107,6 +109,7 @@
         stop_targets = data[6]
         avg_text_length = torch.mean(text_lengths.float())
         avg_spec_length = torch.mean(mel_lengths.float())
+        loader_time = time.time() - end_time
 
         if c.use_speaker_embedding:
             speaker_ids = [speaker_mapping[speaker_name]
@@ -120,8 +123,7 @@
                                          stop_targets.size(1) // c.r, -1)
         stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
 
-        current_step = num_iter + args.restore_step + \
-            epoch * len(data_loader) + 1
+        global_step += 1
 
         # setup lr
         if c.lr_decay:
@@ -176,18 +178,20 @@
             optimizer_st.step()
         else:
             grad_norm_st = 0
-
+
         step_time = time.time() - start_time
         epoch_time += step_time
 
-        if current_step % c.print_step == 0:
+        if global_step % c.print_step == 0:
            print(
                "   | > Step:{}/{}  GlobalStep:{}  TotalLoss:{:.5f}  PostnetLoss:{:.5f}  "
                "DecoderLoss:{:.5f}  StopLoss:{:.5f}  GradNorm:{:.5f}  "
-                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  LR:{:.6f}".format(
-                    num_iter, batch_n_iter, current_step, loss.item(),
+                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  "
+                "LoaderTime:{:.2f}  LR:{:.6f}".format(
+                    num_iter, batch_n_iter, global_step, loss.item(),
                    postnet_loss.item(), decoder_loss.item(), stop_loss.item(),
-                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
+                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time,
+                    loader_time, current_lr),
                flush=True)
 
        # aggregate losses from processes
@@ -202,6 +206,7 @@
            avg_decoder_loss += float(decoder_loss.item())
            avg_stop_loss += stop_loss if isinstance(stop_loss, float) else float(stop_loss.item())
            avg_step_time += step_time
+            avg_loader_time += loader_time
 
            # Plot Training Iter Stats
            iter_stats = {"loss_posnet": postnet_loss.item(),
@@ -210,13 +215,13 @@
                          "grad_norm": grad_norm,
                          "grad_norm_st": grad_norm_st,
                          "step_time": step_time}
-            tb_logger.tb_train_iter_stats(current_step, iter_stats)
+            tb_logger.tb_train_iter_stats(global_step, iter_stats)
 
-            if current_step % c.save_step == 0:
+            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st,
-                                    postnet_loss.item(), OUT_PATH, current_step,
+                                    postnet_loss.item(), OUT_PATH, global_step,
                                    epoch)
 
                # Diagnostic visualizations
@@ -229,31 +234,34 @@
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
-                tb_logger.tb_train_figures(current_step, figures)
+                tb_logger.tb_train_figures(global_step, figures)
 
                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_mel_spectrogram(const_spec.T)
-                tb_logger.tb_train_audios(current_step,
+                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
+        end_time = time.time()
 
    avg_postnet_loss /= (num_iter + 1)
    avg_decoder_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
    avg_step_time /= (num_iter + 1)
+    avg_loader_time /= (num_iter + 1)
 
    # print epoch stats
    print(
        "   | > EPOCH END -- GlobalStep:{}  AvgTotalLoss:{:.5f}  "
        "AvgPostnetLoss:{:.5f}  AvgDecoderLoss:{:.5f}  "
        "AvgStopLoss:{:.5f}  EpochTime:{:.2f}  "
-        "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
+        "AvgStepTime:{:.2f}  AvgLoaderTime:{:.2f}".format(global_step, avg_total_loss,
                                    avg_postnet_loss, avg_decoder_loss,
-                                    avg_stop_loss, epoch_time, avg_step_time),
+                                    avg_stop_loss, epoch_time, avg_step_time,
+                                    avg_loader_time),
        flush=True)
 
    # Plot Epoch Stats
@@ -263,14 +271,13 @@
                   "loss_decoder": avg_decoder_loss,
                   "stop_loss": avg_stop_loss,
                   "epoch_time": epoch_time}
-    tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
+    tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    if c.tb_model_param_stats:
-        tb_logger.tb_model_weights(model, current_step)
-
-    return avg_postnet_loss, current_step
+        tb_logger.tb_model_weights(model, global_step)
+    return avg_postnet_loss, global_step
 
 
-def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
+def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
    data_loader = setup_loader(ap, is_val=True)
    if c.use_speaker_embedding:
        speaker_mapping = load_speaker_mapping(OUT_PATH)
@@ -383,14 +390,14 @@
                "ground_truth": plot_spectrogram(gt_spec, ap),
                "alignment": plot_alignment(align_img)
            }
-            tb_logger.tb_eval_figures(current_step, eval_figures)
+            tb_logger.tb_eval_figures(global_step, eval_figures)
 
            # Sample audio
            if c.model in ["Tacotron", "TacotronGST"]:
                eval_audio = ap.inv_spectrogram(const_spec.T)
            else:
                eval_audio = ap.inv_mel_spectrogram(const_spec.T)
-            tb_logger.tb_eval_audios(current_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
+            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
 
            # compute average losses
            avg_postnet_loss /= (num_iter + 1)
@@ -401,7 +408,7 @@
            epoch_stats = {"loss_postnet": avg_postnet_loss,
                           "loss_decoder": avg_decoder_loss,
                           "stop_loss": avg_stop_loss}
-            tb_logger.tb_eval_stats(current_step, epoch_stats)
+            tb_logger.tb_eval_stats(global_step, epoch_stats)
 
    if args.rank == 0 and epoch > c.test_delay_epochs:
        # test sentences
@@ -427,8 +434,8 @@
            except:
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
-        tb_logger.tb_test_audios(current_step, test_audios, c.audio['sample_rate'])
-        tb_logger.tb_test_figures(current_step, test_figures)
+        tb_logger.tb_test_audios(global_step, test_audios, c.audio['sample_rate'])
+        tb_logger.tb_test_figures(global_step, test_figures)
    return avg_postnet_loss
 
 
@@ -526,11 +533,19 @@ def main(args):  #pylint: disable=redefined-outer-name
    if 'best_loss' not in locals():
        best_loss = float('inf')
 
+    global_step = args.restore_step
    for epoch in range(0, c.epochs):
-        train_loss, current_step = train(model, criterion, criterion_st,
+        # set gradual training
+        if c.gradual_training is not None:
+            r, c.batch_size = gradual_training_scheduler(global_step, c)
+            c.r = r
+            model.decoder._set_r(r)
+            print(" > Number of outputs per iteration:", model.decoder.r)
+
+        train_loss, global_step = train(model, criterion, criterion_st,
                                         optimizer, optimizer_st, scheduler,
-                                         ap, epoch)
-        val_loss = evaluate(model, criterion, criterion_st, ap, current_step, epoch)
+                                         ap, global_step, epoch)
+        val_loss = evaluate(model, criterion, criterion_st, ap, global_step, epoch)
        print(
            " | > Training Loss: {:.5f}   Validation Loss: {:.5f}".format(
                train_loss, val_loss),
@@ -539,7 +554,7 @@
        if c.run_eval:
            target_loss = val_loss
        best_loss = save_best_model(model, optimizer, target_loss, best_loss,
-                                    OUT_PATH, current_step, epoch)
+                                    OUT_PATH, global_step, epoch)
 
 
 if __name__ == '__main__':
@@ -573,7 +588,7 @@
        '--output_folder',
        type=str,
        default='',
-        help='folder name for traning outputs.'
+        help='folder name for training outputs.'
    )
 
    # DISTRUBUTED
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 6cf4f420..51155254 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -121,7 +121,8 @@ def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
        'step': current_step,
        'epoch': epoch,
        'linear_loss': model_loss,
-        'date': datetime.date.today().strftime("%B %d, %Y")
+        'date': datetime.date.today().strftime("%B %d, %Y"),
+        'r': model.decoder.r
    }
    torch.save(state, checkpoint_path)
 
@@ -136,7 +137,8 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
        'step': current_step,
        'epoch': epoch,
        'linear_loss': model_loss,
-        'date': datetime.date.today().strftime("%B %d, %Y")
+        'date': datetime.date.today().strftime("%B %d, %Y"),
+        'r': model.decoder.r
    }
    best_loss = model_loss
    bestmodel_path = 'best_model.pth.tar'
@@ -305,3 +307,10 @@ def split_dataset(items):
    else:
        return items[:eval_split_size], items[eval_split_size:]
 
+
+def gradual_training_scheduler(global_step, config):
+    new_values = None
+    for values in config.gradual_training:
+        if global_step >= values[0]:
+            new_values = values
+    return new_values[1], new_values[2]
\ No newline at end of file
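For reference, `gradual_training_scheduler` above simply returns the `r` and `batch_size` of the last schedule entry whose `first_step` has been passed; `train.py` then applies them at the start of each epoch via `model.decoder._set_r(r)`. A standalone sketch of that lookup, using the schedule from `config.json` with a plain list in place of the config object:

```python
# Schedule entries are [first_step, r, batch_size], as in config.json.
schedule = [[0, 7, 32], [10000, 5, 32], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]]

def pick_r_and_batch_size(global_step, schedule):
    new_values = None
    for first_step, r, batch_size in schedule:  # assumes entries sorted by first_step
        if global_step >= first_step:
            new_values = (r, batch_size)
    return new_values

print(pick_r_and_batch_size(0, schedule))       # (7, 32)
print(pick_r_and_batch_size(60000, schedule))   # (3, 32)
print(pick_r_and_batch_size(300000, schedule))  # (1, 8)
```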