mirror of https://github.com/coqui-ai/TTS.git
Init data_loaders by function beginning of each train and eval run
parent
dc3d09304e
commit
619c73f0f1
|
@ -49,5 +49,6 @@
|
||||||
"dataset": "ljspeech", // one of TTS.dataset.preprocessors, only valid id dataloader == "TTSDataset", rest uses "tts_cache" by default.
|
"dataset": "ljspeech", // one of TTS.dataset.preprocessors, only valid id dataloader == "TTSDataset", rest uses "tts_cache" by default.
|
||||||
"min_seq_len": 0,
|
"min_seq_len": 0,
|
||||||
"output_path": "../keep/",
|
"output_path": "../keep/",
|
||||||
"num_loader_workers": 8
|
"num_loader_workers": 8,
|
||||||
|
"num_val_loader_workers": 4
|
||||||
}
|
}
|
106
train.py
106
train.py
|
@ -29,9 +29,35 @@ print(" > Using CUDA: ", use_cuda)
|
||||||
print(" > Number of GPUs: ", torch.cuda.device_count())
|
print(" > Number of GPUs: ", torch.cuda.device_count())
|
||||||
|
|
||||||
|
|
||||||
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
def setup_loader(is_val=False):
|
||||||
|
global ap
|
||||||
|
if is_val and not c.run_eval:
|
||||||
|
loader = None
|
||||||
|
else:
|
||||||
|
dataset = MyDataset(
|
||||||
|
c.data_path,
|
||||||
|
c.meta_file_val if is_val else c.meta_file_train,
|
||||||
|
c.r,
|
||||||
|
c.text_cleaner,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
ap=ap,
|
||||||
|
batch_group_size=0 if is_val else 8 * c.batch_size,
|
||||||
|
min_seq_len=0 if is_val else c.min_seq_len)
|
||||||
|
loader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=c.eval_batch_size if is_val else c.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=dataset.collate_fn,
|
||||||
|
drop_last=False,
|
||||||
|
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
|
||||||
|
pin_memory=False)
|
||||||
|
return loader
|
||||||
|
|
||||||
|
|
||||||
|
def train(model, criterion, criterion_st, optimizer, optimizer_st,
|
||||||
scheduler, ap, epoch):
|
scheduler, ap, epoch):
|
||||||
model = model.train()
|
data_loader = setup_loader(is_val=False)
|
||||||
|
model.train()
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
avg_linear_loss = 0
|
avg_linear_loss = 0
|
||||||
avg_mel_loss = 0
|
avg_mel_loss = 0
|
||||||
|
@ -212,8 +238,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
return avg_linear_loss, current_step
|
return avg_linear_loss, current_step
|
||||||
|
|
||||||
|
|
||||||
def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
def evaluate(model, criterion, criterion_st, ap, current_step):
|
||||||
model = model.eval()
|
data_loader = setup_loader(is_val=True)
|
||||||
|
model.eval()
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
avg_linear_loss = 0
|
avg_linear_loss = 0
|
||||||
avg_mel_loss = 0
|
avg_mel_loss = 0
|
||||||
|
@ -361,58 +388,6 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
# Conditional imports
|
|
||||||
preprocessor = importlib.import_module('datasets.preprocess')
|
|
||||||
preprocessor = getattr(preprocessor, c.dataset.lower())
|
|
||||||
MyDataset = importlib.import_module('datasets.' + c.data_loader)
|
|
||||||
MyDataset = getattr(MyDataset, "MyDataset")
|
|
||||||
audio = importlib.import_module('utils.' + c.audio['audio_processor'])
|
|
||||||
AudioProcessor = getattr(audio, 'AudioProcessor')
|
|
||||||
|
|
||||||
# Audio processor
|
|
||||||
ap = AudioProcessor(**c.audio)
|
|
||||||
|
|
||||||
# Setup the dataset
|
|
||||||
train_dataset = MyDataset(
|
|
||||||
c.data_path,
|
|
||||||
c.meta_file_train,
|
|
||||||
c.r,
|
|
||||||
c.text_cleaner,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
ap=ap,
|
|
||||||
batch_group_size=8 * c.batch_size,
|
|
||||||
min_seq_len=c.min_seq_len)
|
|
||||||
|
|
||||||
train_loader = DataLoader(
|
|
||||||
train_dataset,
|
|
||||||
batch_size=c.batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
collate_fn=train_dataset.collate_fn,
|
|
||||||
drop_last=False,
|
|
||||||
num_workers=c.num_loader_workers,
|
|
||||||
pin_memory=True)
|
|
||||||
|
|
||||||
if c.run_eval:
|
|
||||||
val_dataset = MyDataset(
|
|
||||||
c.data_path,
|
|
||||||
c.meta_file_val,
|
|
||||||
c.r,
|
|
||||||
c.text_cleaner,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
ap=ap,
|
|
||||||
batch_group_size=0)
|
|
||||||
|
|
||||||
val_loader = DataLoader(
|
|
||||||
val_dataset,
|
|
||||||
batch_size=c.eval_batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
collate_fn=val_dataset.collate_fn,
|
|
||||||
drop_last=False,
|
|
||||||
num_workers=4,
|
|
||||||
pin_memory=True)
|
|
||||||
else:
|
|
||||||
val_loader = None
|
|
||||||
|
|
||||||
model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r)
|
model = Tacotron(c.embedding_size, ap.num_freq, ap.num_mels, c.r)
|
||||||
print(" | > Num output units : {}".format(ap.num_freq), flush=True)
|
print(" | > Num output units : {}".format(ap.num_freq), flush=True)
|
||||||
|
|
||||||
|
@ -433,7 +408,7 @@ def main(args):
|
||||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
print(
|
print(
|
||||||
" > Model restored from step %d" % checkpoint['step'], flush=True)
|
" > Model restored from step %d" % checkpoint['step'], flush=True)
|
||||||
start_epoch = checkpoint['step'] // len(train_loader)
|
start_epoch = checkpoint['epoch']
|
||||||
best_loss = checkpoint['linear_loss']
|
best_loss = checkpoint['linear_loss']
|
||||||
args.restore_step = checkpoint['step']
|
args.restore_step = checkpoint['step']
|
||||||
else:
|
else:
|
||||||
|
@ -463,9 +438,9 @@ def main(args):
|
||||||
|
|
||||||
for epoch in range(0, c.epochs):
|
for epoch in range(0, c.epochs):
|
||||||
train_loss, current_step = train(model, criterion, criterion_st,
|
train_loss, current_step = train(model, criterion, criterion_st,
|
||||||
train_loader, optimizer, optimizer_st,
|
optimizer, optimizer_st,
|
||||||
scheduler, ap, epoch)
|
scheduler, ap, epoch)
|
||||||
val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
|
val_loss = evaluate(model, criterion, criterion_st, ap,
|
||||||
current_step)
|
current_step)
|
||||||
print(
|
print(
|
||||||
" | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(
|
" | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(
|
||||||
|
@ -473,8 +448,6 @@ def main(args):
|
||||||
flush=True)
|
flush=True)
|
||||||
best_loss = save_best_model(model, optimizer, train_loss, best_loss,
|
best_loss = save_best_model(model, optimizer, train_loss, best_loss,
|
||||||
OUT_PATH, current_step, epoch)
|
OUT_PATH, current_step, epoch)
|
||||||
# shuffle batch groups
|
|
||||||
train_loader.dataset.sort_items()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -515,6 +488,17 @@ if __name__ == '__main__':
|
||||||
LOG_DIR = OUT_PATH
|
LOG_DIR = OUT_PATH
|
||||||
tb = SummaryWriter(LOG_DIR)
|
tb = SummaryWriter(LOG_DIR)
|
||||||
|
|
||||||
|
# Conditional imports
|
||||||
|
preprocessor = importlib.import_module('datasets.preprocess')
|
||||||
|
preprocessor = getattr(preprocessor, c.dataset.lower())
|
||||||
|
MyDataset = importlib.import_module('datasets.' + c.data_loader)
|
||||||
|
MyDataset = getattr(MyDataset, "MyDataset")
|
||||||
|
audio = importlib.import_module('utils.' + c.audio['audio_processor'])
|
||||||
|
AudioProcessor = getattr(audio, 'AudioProcessor')
|
||||||
|
|
||||||
|
# Audio processor
|
||||||
|
ap = AudioProcessor(**c.audio)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
main(args)
|
main(args)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
|
Loading…
Reference in New Issue