Init data_loaders by function at the beginning of each train and eval run

Branch: pull/10/head
Author: Eren Golge · 2018-12-11 17:52:43 +01:00
Parent: dc3d09304e
Commit: 619c73f0f1
2 changed files with 47 additions and 62 deletions
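
In short: data-loader construction moves out of main() and into a new setup_loader(is_val=False) helper, so train() and evaluate() each build a fresh DataLoader at the start of every run. The conditional imports and the AudioProcessor move from main() into the __main__ block, where setup_loader() can reach them as module-level names, and a new num_val_loader_workers config key sets the validation loader's worker count separately from training.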

config.json

@@ -49,5 +49,6 @@
     "dataset": "ljspeech",   // one of TTS.dataset.preprocessors, only valid if dataloader == "TTSDataset", rest uses "tts_cache" by default.
     "min_seq_len": 0,
     "output_path": "../keep/",
-    "num_loader_workers": 8
+    "num_loader_workers": 8,
+    "num_val_loader_workers": 4
 }
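
The added num_val_loader_workers key gives the validation loader its own worker count (4 by default, versus 8 for training); setup_loader() in train.py below selects between the two.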

train.py (106 changed lines)

@@ -29,9 +29,35 @@ print(" > Using CUDA: ", use_cuda)
 print(" > Number of GPUs: ", torch.cuda.device_count())
 
 
-def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
+def setup_loader(is_val=False):
+    global ap
+    if is_val and not c.run_eval:
+        loader = None
+    else:
+        dataset = MyDataset(
+            c.data_path,
+            c.meta_file_val if is_val else c.meta_file_train,
+            c.r,
+            c.text_cleaner,
+            preprocessor=preprocessor,
+            ap=ap,
+            batch_group_size=0 if is_val else 8 * c.batch_size,
+            min_seq_len=0 if is_val else c.min_seq_len)
+        loader = DataLoader(
+            dataset,
+            batch_size=c.eval_batch_size if is_val else c.batch_size,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=False,
+            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
+            pin_memory=False)
+    return loader
+
+
+def train(model, criterion, criterion_st, optimizer, optimizer_st,
           scheduler, ap, epoch):
-    model = model.train()
+    data_loader = setup_loader(is_val=False)
+    model.train()
     epoch_time = 0
     avg_linear_loss = 0
     avg_mel_loss = 0
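
Stripped of the TTS-specific arguments, the pattern setup_loader() implements is a single constructor that toggles every train/eval difference on one flag. A minimal runnable sketch, using hypothetical stand-ins (Cfg, DummyDataset) for the repo's config object c and MyDataset:

# A minimal sketch of the setup_loader() pattern; Cfg and DummyDataset are
# hypothetical stand-ins, not part of this repo.
from dataclasses import dataclass

from torch.utils.data import DataLoader, Dataset


@dataclass
class Cfg:
    run_eval: bool = True
    batch_size: int = 32
    eval_batch_size: int = 16
    num_loader_workers: int = 8
    num_val_loader_workers: int = 4


class DummyDataset(Dataset):
    def __len__(self):
        return 128

    def __getitem__(self, idx):
        return idx


c = Cfg()


def setup_loader(is_val=False):
    # Mirror the diff: evaluation disabled means no loader at all.
    if is_val and not c.run_eval:
        return None
    return DataLoader(
        DummyDataset(),
        batch_size=c.eval_batch_size if is_val else c.batch_size,
        shuffle=False,
        num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
        drop_last=False,
        pin_memory=False)


train_loader = setup_loader(is_val=False)  # fresh loader per train() call
val_loader = setup_loader(is_val=True)     # None when c.run_eval is False

Returning None when run_eval is off keeps the call sites uniform: both train() and evaluate() go through the one entry point.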
@@ -212,8 +238,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     return avg_linear_loss, current_step
 
 
-def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
-    model = model.eval()
+def evaluate(model, criterion, criterion_st, ap, current_step):
+    data_loader = setup_loader(is_val=True)
+    model.eval()
     epoch_time = 0
     avg_linear_loss = 0
     avg_mel_loss = 0
@@ -361,58 +388,6 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
 def main(args):
-    # Conditional imports
-    preprocessor = importlib.import_module('datasets.preprocess')
-    preprocessor = getattr(preprocessor, c.dataset.lower())
-    MyDataset = importlib.import_module('datasets.' + c.data_loader)
-    MyDataset = getattr(MyDataset, "MyDataset")
-    audio = importlib.import_module('utils.' + c.audio['audio_processor'])
-    AudioProcessor = getattr(audio, 'AudioProcessor')
-
-    # Audio processor
-    ap = AudioProcessor(**c.audio)
-
-    # Setup the dataset
-    train_dataset = MyDataset(
-        c.data_path,
-        c.meta_file_train,
-        c.r,
-        c.text_cleaner,
-        preprocessor=preprocessor,
-        ap=ap,
-        batch_group_size=8 * c.batch_size,
-        min_seq_len=c.min_seq_len)
-    train_loader = DataLoader(
-        train_dataset,
-        batch_size=c.batch_size,
-        shuffle=False,
-        collate_fn=train_dataset.collate_fn,
-        drop_last=False,
-        num_workers=c.num_loader_workers,
-        pin_memory=True)
-    if c.run_eval:
-        val_dataset = MyDataset(
-            c.data_path,
-            c.meta_file_val,
-            c.r,
-            c.text_cleaner,
-            preprocessor=preprocessor,
-            ap=ap,
-            batch_group_size=0)
-        val_loader = DataLoader(
-            val_dataset,
-            batch_size=c.eval_batch_size,
-            shuffle=False,
-            collate_fn=val_dataset.collate_fn,
-            drop_last=False,
-            num_workers=4,
-            pin_memory=True)
-    else:
-        val_loader = None
-
     model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r)
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
@@ -433,7 +408,7 @@ def main(args):
         optimizer.load_state_dict(checkpoint['optimizer'])
         print(
             " > Model restored from step %d" % checkpoint['step'], flush=True)
-        start_epoch = checkpoint['step'] // len(train_loader)
+        start_epoch = checkpoint['epoch']
         best_loss = checkpoint['linear_loss']
         args.restore_step = checkpoint['step']
     else:
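
Reading start_epoch from checkpoint['epoch'] is forced by the refactor: main() no longer has a train_loader whose length could turn a global step back into an epoch index. It relies on checkpoints carrying an epoch field, which save_best_model() is already passed below.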
@@ -463,9 +438,9 @@ def main(args):
     for epoch in range(0, c.epochs):
         train_loss, current_step = train(model, criterion, criterion_st,
-                                         train_loader, optimizer, optimizer_st,
+                                         optimizer, optimizer_st,
                                          scheduler, ap, epoch)
-        val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
+        val_loss = evaluate(model, criterion, criterion_st, ap,
                             current_step)
         print(
             " | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(
@@ -473,8 +448,6 @@ def main(args):
             flush=True)
         best_loss = save_best_model(model, optimizer, train_loss, best_loss,
                                     OUT_PATH, current_step, epoch)
-        # shuffle batch groups
-        train_loader.dataset.sort_items()
 
 
 if __name__ == '__main__':
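
The dropped sort_items() call is covered by the refactor as well: since setup_loader() constructs a fresh MyDataset at the start of every train() call, the batch-group shuffling that sort_items() used to redo each epoch now presumably happens in the dataset's constructor instead.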
@@ -515,6 +488,17 @@ if __name__ == '__main__':
     LOG_DIR = OUT_PATH
     tb = SummaryWriter(LOG_DIR)
 
+    # Conditional imports
+    preprocessor = importlib.import_module('datasets.preprocess')
+    preprocessor = getattr(preprocessor, c.dataset.lower())
+    MyDataset = importlib.import_module('datasets.' + c.data_loader)
+    MyDataset = getattr(MyDataset, "MyDataset")
+    audio = importlib.import_module('utils.' + c.audio['audio_processor'])
+    AudioProcessor = getattr(audio, 'AudioProcessor')
+
+    # Audio processor
+    ap = AudioProcessor(**c.audio)
+
     try:
         main(args)
     except KeyboardInterrupt:
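
Moving the conditional imports and the AudioProcessor into the __main__ block makes preprocessor, MyDataset, and ap module-level names, which is exactly what setup_loader() depends on (note the global ap in its body).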