From 619c73f0f17a8987857b9bc140ed1d0cb1c9353c Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 11 Dec 2018 17:52:43 +0100 Subject: [PATCH] Init data_loaders by function beginning of each train and eval run --- config.json | 3 +- train.py | 106 ++++++++++++++++++++++------------------------------ 2 files changed, 47 insertions(+), 62 deletions(-) diff --git a/config.json b/config.json index 9d66b45d..05ab24f8 100644 --- a/config.json +++ b/config.json @@ -49,5 +49,6 @@ "dataset": "ljspeech", // one of TTS.dataset.preprocessors, only valid id dataloader == "TTSDataset", rest uses "tts_cache" by default. "min_seq_len": 0, "output_path": "../keep/", - "num_loader_workers": 8 + "num_loader_workers": 8, + "num_val_loader_workers": 4 } \ No newline at end of file diff --git a/train.py b/train.py index 769b4f6c..54d6140f 100644 --- a/train.py +++ b/train.py @@ -29,9 +29,35 @@ print(" > Using CUDA: ", use_cuda) print(" > Number of GPUs: ", torch.cuda.device_count()) -def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, +def setup_loader(is_val=False): + global ap + if is_val and not c.run_eval: + loader = None + else: + dataset = MyDataset( + c.data_path, + c.meta_file_val if is_val else c.meta_file_train, + c.r, + c.text_cleaner, + preprocessor=preprocessor, + ap=ap, + batch_group_size=0 if is_val else 8 * c.batch_size, + min_seq_len=0 if is_val else c.min_seq_len) + loader = DataLoader( + dataset, + batch_size=c.eval_batch_size if is_val else c.batch_size, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=False, + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False) + return loader + + +def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, ap, epoch): - model = model.train() + data_loader = setup_loader(is_val=False) + model.train() epoch_time = 0 avg_linear_loss = 0 avg_mel_loss = 0 @@ -212,8 +238,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, return avg_linear_loss, current_step -def evaluate(model, criterion, criterion_st, data_loader, ap, current_step): - model = model.eval() +def evaluate(model, criterion, criterion_st, ap, current_step): + data_loader = setup_loader(is_val=True) + model.eval() epoch_time = 0 avg_linear_loss = 0 avg_mel_loss = 0 @@ -361,58 +388,6 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step): def main(args): - # Conditional imports - preprocessor = importlib.import_module('datasets.preprocess') - preprocessor = getattr(preprocessor, c.dataset.lower()) - MyDataset = importlib.import_module('datasets.' + c.data_loader) - MyDataset = getattr(MyDataset, "MyDataset") - audio = importlib.import_module('utils.' + c.audio['audio_processor']) - AudioProcessor = getattr(audio, 'AudioProcessor') - - # Audio processor - ap = AudioProcessor(**c.audio) - - # Setup the dataset - train_dataset = MyDataset( - c.data_path, - c.meta_file_train, - c.r, - c.text_cleaner, - preprocessor=preprocessor, - ap=ap, - batch_group_size=8 * c.batch_size, - min_seq_len=c.min_seq_len) - - train_loader = DataLoader( - train_dataset, - batch_size=c.batch_size, - shuffle=False, - collate_fn=train_dataset.collate_fn, - drop_last=False, - num_workers=c.num_loader_workers, - pin_memory=True) - - if c.run_eval: - val_dataset = MyDataset( - c.data_path, - c.meta_file_val, - c.r, - c.text_cleaner, - preprocessor=preprocessor, - ap=ap, - batch_group_size=0) - - val_loader = DataLoader( - val_dataset, - batch_size=c.eval_batch_size, - shuffle=False, - collate_fn=val_dataset.collate_fn, - drop_last=False, - num_workers=4, - pin_memory=True) - else: - val_loader = None - model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r) print(" | > Num output units : {}".format(ap.num_freq), flush=True) @@ -433,7 +408,7 @@ def main(args): optimizer.load_state_dict(checkpoint['optimizer']) print( " > Model restored from step %d" % checkpoint['step'], flush=True) - start_epoch = checkpoint['step'] // len(train_loader) + start_epoch = checkpoint['epoch'] best_loss = checkpoint['linear_loss'] args.restore_step = checkpoint['step'] else: @@ -463,9 +438,9 @@ def main(args): for epoch in range(0, c.epochs): train_loss, current_step = train(model, criterion, criterion_st, - train_loader, optimizer, optimizer_st, + optimizer, optimizer_st, scheduler, ap, epoch) - val_loss = evaluate(model, criterion, criterion_st, val_loader, ap, + val_loss = evaluate(model, criterion, criterion_st, ap, current_step) print( " | > Train Loss: {:.5f} Validation Loss: {:.5f}".format( @@ -473,8 +448,6 @@ def main(args): flush=True) best_loss = save_best_model(model, optimizer, train_loss, best_loss, OUT_PATH, current_step, epoch) - # shuffle batch groups - train_loader.dataset.sort_items() if __name__ == '__main__': @@ -515,6 +488,17 @@ if __name__ == '__main__': LOG_DIR = OUT_PATH tb = SummaryWriter(LOG_DIR) + # Conditional imports + preprocessor = importlib.import_module('datasets.preprocess') + preprocessor = getattr(preprocessor, c.dataset.lower()) + MyDataset = importlib.import_module('datasets.' + c.data_loader) + MyDataset = getattr(MyDataset, "MyDataset") + audio = importlib.import_module('utils.' + c.audio['audio_processor']) + AudioProcessor = getattr(audio, 'AudioProcessor') + + # Audio processor + ap = AudioProcessor(**c.audio) + try: main(args) except KeyboardInterrupt: