Set up the training scripts to optionally compute phonemes before training. Also define the data loaders once, before training starts, and reuse them instead of redefining them on every train and eval call. This enables better instance filtering based on input length.

pull/10/head
erogol 2020-12-07 11:26:57 +01:00
parent 7c3cdced1a
commit affe1c1138
5 changed files with 51 additions and 21 deletions
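Both training scripts below get the same reshuffle: the DataLoaders are built once in main() and passed into data_depended_init(), train(), and evaluate(), instead of being rebuilt inside each call. Condensed to just the moving parts, the new flow looks roughly like this (names and call order taken from the diff; the rest of main() is elided, and data_depended_init exists only in the first script):

    # condensed view of the refactor: one loader per split, created up front
    train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
    eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)

    global_step = args.restore_step
    model = data_depended_init(train_loader, model, ap)
    for epoch in range(0, c.epochs):
        # both calls now receive the prebuilt loader as their first argument
        train_avg_loss_dict, global_step = train(train_loader, model, criterion,
                                                 optimizer, scheduler, ap,
                                                 global_step, epoch)
        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
                                      global_step, epoch)

Besides avoiding repeated loader construction, this means the optional phoneme precomputation in setup_loader() runs once per split rather than once per epoch.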


@@ -59,6 +59,12 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             enable_eos_bos=c.enable_eos_bos_chars,
             verbose=verbose,
             speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
+        if c.use_phonemes and c.compute_input_seq_cache:
+            # precompute phonemes to have a better estimate of sequence lengths.
+            dataset.compute_input_seq(c.num_loader_workers)
+        dataset.sort_items()
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
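The hunk above calls dataset.compute_input_seq(c.num_loader_workers), but the method body is not part of this diff. As a purely hypothetical sketch of what such a precompute step can look like (the item layout and the phonemize() helper are assumptions, not the repo's actual code):

    # hypothetical sketch of a phoneme precompute step; not the repo's code.
    # assumes self.items is a list of [text, wav_file, ...] entries and that
    # phonemize(text) returns the phoneme sequence for one utterance.
    from multiprocessing import Pool

    def compute_input_seq(self, num_workers=0):
        texts = [item[0] for item in self.items]
        if num_workers > 0:
            # phonemization is CPU-bound, so fan it out over worker processes
            with Pool(num_workers) as pool:
                seqs = pool.map(phonemize, texts)
        else:
            seqs = [phonemize(text) for text in texts]
        # keep the sequences so sort_items() can filter on real input lengths
        for item, seq in zip(self.items, seqs):
            item[0] = seq
        self.input_seq_computed = True

Either way, the point is that the work happens once before the first epoch, so the sort_items() call that follows can measure every instance's true input length.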
@@ -112,7 +118,7 @@ def format_data(data):
            avg_text_length, avg_spec_length, attn_mask, item_idx


-def data_depended_init(model, ap):
+def data_depended_init(data_loader, model, ap):
     """Data depended initialization for activation normalization."""
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
@@ -123,7 +129,6 @@ def data_depended_init(model, ap):
         if getattr(f, "set_ddi", False):
             f.set_ddi(True)
-    data_loader = setup_loader(ap, 1, is_val=False)
     model.train()
     print(" > Data depended initialization ... ")
     num_iter = 0
@@ -152,10 +157,9 @@ def data_depended_init(model, ap):
     return model


-def train(model, criterion, optimizer, scheduler,
+def train(data_loader, model, criterion, optimizer, scheduler,
           ap, global_step, epoch):
-    data_loader = setup_loader(ap, 1, is_val=False,
-                               verbose=(epoch == 0))
     model.train()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -308,8 +312,7 @@ def train(model, criterion, optimizer, scheduler,


 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch):
-    data_loader = setup_loader(ap, 1, is_val=True)
+def evaluate(data_loader, model, criterion, ap, global_step, epoch):
     model.eval()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -533,14 +536,18 @@ def main(args): # pylint: disable=redefined-outer-name
     if 'best_loss' not in locals():
         best_loss = float('inf')

+    # define dataloaders
+    train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
+    eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
+
     global_step = args.restore_step
-    model = data_depended_init(model, ap)
+    model = data_depended_init(train_loader, model, ap)
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
-        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
+        train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
                                                  scheduler, ap, global_step,
                                                  epoch)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
+        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_loss']
         if c.run_eval:


@@ -61,6 +61,12 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             enable_eos_bos=c.enable_eos_bos_chars,
             verbose=verbose,
             speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
+        if c.use_phonemes and c.compute_input_seq_cache:
+            # precompute phonemes to have a better estimate of sequence lengths.
+            dataset.compute_input_seq(c.num_loader_workers)
+        dataset.sort_items()
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
@@ -123,10 +129,8 @@ def format_data(data):
     return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length


-def train(model, criterion, optimizer, optimizer_st, scheduler,
+def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
           ap, global_step, epoch, scaler, scaler_st):
-    data_loader = setup_loader(ap, model.decoder.r, is_val=False,
-                               verbose=(epoch == 0), speaker_mapping=speaker_mapping)
     model.train()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -324,8 +328,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,


 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch):
-    data_loader = setup_loader(ap, model.decoder.r, is_val=True, speaker_mapping=speaker_mapping)
+def evaluate(data_loader, model, criterion, ap, global_step, epoch):
     model.eval()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -583,6 +586,13 @@ def main(args): # pylint: disable=redefined-outer-name
     if 'best_loss' not in locals():
         best_loss = float('inf')

+    # define data loaders
+    train_loader = setup_loader(ap,
+                                model.decoder.r,
+                                is_val=False,
+                                verbose=True)
+    eval_loader = setup_loader(ap, model.decoder.r, is_val=True)
+
     global_step = args.restore_step
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
@@ -594,16 +604,27 @@ def main(args): # pylint: disable=redefined-outer-name
             if c.bidirectional_decoder:
                 model.decoder_backward.set_r(r)
             print("\n > Number of output frames:", model.decoder.r)
-        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
+        train_avg_loss_dict, global_step = train(train_loader, model,
+                                                 criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
-                                                 global_step, epoch, scaler, scaler_st)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
+                                                 global_step, epoch, scaler,
+                                                 scaler_st)
+        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
+                                      global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_postnet_loss']
-        best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH, scaler=scaler.state_dict() if c.mixed_precision else None)
+        best_loss = save_best_model(
+            target_loss,
+            best_loss,
+            model,
+            optimizer,
+            global_step,
+            epoch,
+            c.r,
+            OUT_PATH,
+            scaler=scaler.state_dict() if c.mixed_precision else None)


 if __name__ == '__main__':


@@ -131,6 +131,7 @@
     "batch_group_size": 4,  // Number of batches to shuffle after bucketing.
     "min_seq_len": 6,       // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 153,     // DATASET-RELATED: maximum text length
+    "compute_input_seq_cache": false,  // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage.

     // PATHS
     "output_path": "/home/erogol/Models/LJSpeech/",


@@ -105,6 +105,7 @@
     "min_seq_len": 3,       // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 500,     // DATASET-RELATED: maximum text length
     "compute_f0": false,    // compute f0 values in data-loader
+    "compute_input_seq_cache": false,  // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage.

     // PATHS
     "output_path": "/home/erogol/Models/LJSpeech/",


@@ -63,6 +63,7 @@ class MyDataset(Dataset):
         self.enable_eos_bos = enable_eos_bos
         self.speaker_mapping = speaker_mapping
         self.verbose = verbose
+        self.input_seq_computed = False
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
             os.makedirs(phoneme_cache_path, exist_ok=True)
         if self.verbose:
@@ -71,7 +72,6 @@ class MyDataset(Dataset):
             if use_phonemes:
                 print(" | > phoneme language: {}".format(phoneme_language))
             print(" | > Number of instances : {}".format(len(self.items)))
-        self.sort_items()

     def load_wav(self, filename):
         audio = self.ap.load_wav(filename)
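The two MyDataset changes exist because sort_items() used to run inside __init__, before any phonemes were computed, so length filtering could only look at raw character counts. Moving the call into setup_loader(), after the optional compute_input_seq(), lets the min_seq_len/max_seq_len filter from the configs above see the actual input sequence. A minimal sketch of that filtering idea, again an assumption rather than the method's real body:

    # minimal sketch of length-based filtering and bucketing; not the repo's code.
    # after compute_input_seq, item[0] holds the phoneme sequence, so len(item[0])
    # is the true input length; otherwise it is the raw character count.
    def sort_items(self):
        lengths = [len(item[0]) for item in self.items]
        # drop instances outside the configured length window
        keep = [i for i, l in enumerate(lengths)
                if self.min_seq_len <= l <= self.max_seq_len]
        # ascending length order gives batch_group_size bucketing its structure
        keep.sort(key=lambda i: lengths[i])
        self.items = [self.items[i] for i in keep]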