From 8ff9253abd027af0e5cefc85cd878baa1d13295d Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 17 Dec 2018 16:37:06 +0100 Subject: [PATCH] bug fixes on train.py --- train.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index 783cf5b8..7dd5b78d 100644 --- a/train.py +++ b/train.py @@ -43,7 +43,9 @@ def setup_loader(is_val=False): preprocessor=preprocessor, ap=ap, batch_group_size=0 if is_val else 8 * c.batch_size, - min_seq_len=0 if is_val else c.min_seq_len) + min_seq_len=0 if is_val else c.min_seq_len, + max_seq_len=float("inf") if is_val else c.max_seq_len + cached=False if c.dataset ~= "tts_cache" else True) loader = DataLoader( dataset, batch_size=c.eval_batch_size if is_val else c.batch_size, @@ -164,8 +166,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr), flush=True) - avg_linear_loss += linear_loss.item() - avg_mel_loss += mel_loss.item() + avg_linear_loss += float(linear_loss.item()) + avg_mel_loss += float(mel_loss.item()) avg_stop_loss += stop_loss.item() avg_step_time += step_time @@ -198,7 +200,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, # Sample audio tb_logger.tb_train_audios(current_step, {'TrainAudio': ap.inv_spectrogram(const_spec.T)}, - c.sample_rate) + c.audio["sample_rate"]) avg_linear_loss /= (num_iter + 1) avg_mel_loss /= (num_iter + 1) @@ -295,8 +297,8 @@ def evaluate(model, criterion, criterion_st, ap, current_step): stop_loss.item()), flush=True) - avg_linear_loss += linear_loss.item() - avg_mel_loss += mel_loss.item() + avg_linear_loss += float(linear_loss.item()) + avg_mel_loss += float(mel_loss.item()) avg_stop_loss += stop_loss.item() # Diagnostic visualizations @@ -442,7 +444,7 @@ if __name__ == '__main__': default=False, help='Do not verify commit integrity to run training.') parser.add_argument( - '--data_path', type=str, help='dataset path.', default='Defines the data path. It overwrites config.json.') + '--data_path', type=str, default='', default='Defines the data path. It overwrites config.json.') args = parser.parse_args() # setup output paths and read configs