mirror of https://github.com/coqui-ai/TTS.git
Set up the training scripts to optionally compute phonemes before training, and define data loaders once before training starts, re-using them instead of re-defining them for every train and eval call. This enables better instance filtering based on input length.
parent 7c3cdced1a
commit affe1c1138
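In outline, the commit lifts DataLoader construction out of train()/evaluate() and into main(), so each loader is built exactly once and handed down. A minimal, self-contained sketch of that pattern with a toy model and dataset follows; it illustrates the idea only and is not the repository's code.

import torch
from torch.utils.data import DataLoader, TensorDataset

def train(data_loader, model, optimizer):
    model.train()
    for x, y in data_loader:  # loader was built by the caller
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()

@torch.no_grad()
def evaluate(data_loader, model):
    model.eval()
    return sum(torch.nn.functional.mse_loss(model(x), y).item()
               for x, y in data_loader)

def main():
    dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    # defined once, re-used every epoch instead of re-created per call
    train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
    eval_loader = DataLoader(dataset, batch_size=8)
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for epoch in range(2):
        train(train_loader, model, optimizer)
        print(epoch, evaluate(eval_loader, model))

if __name__ == '__main__':
    main()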
@@ -59,6 +59,12 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             enable_eos_bos=c.enable_eos_bos_chars,
             verbose=verbose,
             speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
+
+        if c.use_phonemes and c.compute_input_seq_cache:
+            # precompute phonemes to have a better estimate of sequence lengths.
+            dataset.compute_input_seq(c.num_loader_workers)
+        dataset.sort_items()
+
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
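The hunk calls dataset.compute_input_seq(c.num_loader_workers), whose body is defined elsewhere in the commit and not shown here. A plausible sketch of such a precompute step, assuming each item stores its raw text at index 0 and that some text-to-ID conversion exists; apart from compute_input_seq and input_seq_computed, every name below is hypothetical:

from multiprocessing import Pool

def _text_to_seq(text):
    # hypothetical stand-in for the real text/phoneme-to-ID conversion
    return [ord(ch) for ch in text]

class InputSeqPrecompute:
    def __init__(self, items):
        self.items = items  # e.g. [[text, wav_path], ...]
        self.input_seq_computed = False

    def compute_input_seq(self, num_workers=0):
        """Convert every raw text to its input sequence up front, so true
        sequence lengths are known before length-based filtering."""
        texts = [item[0] for item in self.items]  # assumes text at index 0
        if num_workers > 1:
            with Pool(num_workers) as pool:
                seqs = pool.map(_text_to_seq, texts)
        else:
            seqs = [_text_to_seq(t) for t in texts]
        for item, seq in zip(self.items, seqs):
            item.append(seq)  # cache the sequence next to the item
        self.input_seq_computed = True

# usage: ds = InputSeqPrecompute([["hello"], ["world"]]); ds.compute_input_seq(2)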
@@ -112,7 +118,7 @@ def format_data(data):
         avg_text_length, avg_spec_length, attn_mask, item_idx


-def data_depended_init(model, ap):
+def data_depended_init(data_loader, model, ap):
     """Data depended initialization for activation normalization."""
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
@@ -123,7 +129,6 @@ def data_depended_init(model, ap):
             if getattr(f, "set_ddi", False):
                 f.set_ddi(True)

-    data_loader = setup_loader(ap, 1, is_val=False)
     model.train()
     print(" > Data depended initialization ... ")
     num_iter = 0
@@ -152,10 +157,9 @@ def data_depended_init(model, ap):
     return model


-def train(model, criterion, optimizer, scheduler,
-          ap, global_step, epoch):
-    data_loader = setup_loader(ap, 1, is_val=False,
-                               verbose=(epoch == 0))
+def train(data_loader, model, criterion, optimizer, scheduler,
+          ap, global_step, epoch):
+
     model.train()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -308,8 +312,7 @@ def train(model, criterion, optimizer, scheduler,


 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch):
-    data_loader = setup_loader(ap, 1, is_val=True)
+def evaluate(data_loader, model, criterion, ap, global_step, epoch):
     model.eval()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -533,14 +536,18 @@ def main(args): # pylint: disable=redefined-outer-name
     if 'best_loss' not in locals():
         best_loss = float('inf')

+    # define dataloaders
+    train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
+    eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
+
     global_step = args.restore_step
-    model = data_depended_init(model, ap)
+    model = data_depended_init(train_loader, model, ap)
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
-        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
+        train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
                                                  scheduler, ap, global_step,
                                                  epoch)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
+        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_loss']
         if c.run_eval:
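A side effect of building loaders once is worth noting: setup_loader wraps the dataset in a DistributedSampler when num_gpus > 1, and a re-used DistributedSampler reshuffles only when told the current epoch. PyTorch's documented idiom for that is set_epoch(); a small sketch of an epoch loop that accounts for it (this call is not part of the diff shown here):

from torch.utils.data.distributed import DistributedSampler

def run_epochs(data_loader, num_epochs):
    for epoch in range(num_epochs):
        # re-seed the sampler so every epoch sees a fresh shuffle order
        if isinstance(data_loader.sampler, DistributedSampler):
            data_loader.sampler.set_epoch(epoch)
        for batch in data_loader:
            pass  # training step goes here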
@@ -61,6 +61,12 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             enable_eos_bos=c.enable_eos_bos_chars,
             verbose=verbose,
             speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
+
+        if c.use_phonemes and c.compute_input_seq_cache:
+            # precompute phonemes to have a better estimate of sequence lengths.
+            dataset.compute_input_seq(c.num_loader_workers)
+        dataset.sort_items()
+
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
@@ -123,10 +129,8 @@ def format_data(data):
     return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length


-def train(model, criterion, optimizer, optimizer_st, scheduler,
+def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
           ap, global_step, epoch, scaler, scaler_st):
-    data_loader = setup_loader(ap, model.decoder.r, is_val=False,
-                               verbose=(epoch == 0), speaker_mapping=speaker_mapping)
     model.train()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -324,8 +328,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,


 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch):
-    data_loader = setup_loader(ap, model.decoder.r, is_val=True, speaker_mapping=speaker_mapping)
+def evaluate(data_loader, model, criterion, ap, global_step, epoch):
     model.eval()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -583,6 +586,13 @@ def main(args): # pylint: disable=redefined-outer-name
     if 'best_loss' not in locals():
         best_loss = float('inf')

+    # define data loaders
+    train_loader = setup_loader(ap,
+                                model.decoder.r,
+                                is_val=False,
+                                verbose=True)
+    eval_loader = setup_loader(ap, model.decoder.r, is_val=True)
+
     global_step = args.restore_step
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
@@ -594,16 +604,27 @@ def main(args): # pylint: disable=redefined-outer-name
         if c.bidirectional_decoder:
             model.decoder_backward.set_r(r)
         print("\n > Number of output frames:", model.decoder.r)
-        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
+        train_avg_loss_dict, global_step = train(train_loader, model,
+                                                 criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
-                                                 global_step, epoch, scaler, scaler_st)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
+                                                 global_step, epoch, scaler,
+                                                 scaler_st)
+        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
+                                      global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
         if c.run_eval:
             target_loss = eval_avg_loss_dict['avg_postnet_loss']
-        best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH, scaler=scaler.state_dict() if c.mixed_precision else None)
+        best_loss = save_best_model(
+            target_loss,
+            best_loss,
+            model,
+            optimizer,
+            global_step,
+            epoch,
+            c.r,
+            OUT_PATH,
+            scaler=scaler.state_dict() if c.mixed_precision else None)


 if __name__ == '__main__':
@@ -131,6 +131,7 @@
     "batch_group_size": 4, //Number of batches to shuffle after bucketing.
     "min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 153, // DATASET-RELATED: maximum text length
+    "compute_input_seq_cache": false, // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage.

     // PATHS
     "output_path": "/home/erogol/Models/LJSpeech/",
@@ -105,6 +105,7 @@
     "min_seq_len": 3, // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 500, // DATASET-RELATED: maximum text length
     "compute_f0": false, // compute f0 values in data-loader
+    "compute_input_seq_cache": false, // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage.

     // PATHS
     "output_path": "/home/erogol/Models/LJSpeech/",
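Per the setup_loader hunks above, the precompute branch runs only when both use_phonemes and compute_input_seq_cache are true. A minimal excerpt of the relevant config keys switched on (values are illustrative):

    "use_phonemes": true, // required for the phoneme precompute branch
    "compute_input_seq_cache": true, // enables precomputing input sequences before training
    "num_loader_workers": 4, // worker count passed to compute_input_seq()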
@@ -63,6 +63,7 @@ class MyDataset(Dataset):
         self.enable_eos_bos = enable_eos_bos
         self.speaker_mapping = speaker_mapping
         self.verbose = verbose
+        self.input_seq_computed = False
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
             os.makedirs(phoneme_cache_path, exist_ok=True)
         if self.verbose:
@@ -71,7 +72,6 @@ class MyDataset(Dataset):
             if use_phonemes:
                 print(" | > phoneme language: {}".format(phoneme_language))
             print(" | > Number of instances : {}".format(len(self.items)))
-        self.sort_items()

     def load_wav(self, filename):
         audio = self.ap.load_wav(filename)
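sort_items is moved out of MyDataset.__init__ so that it can run after the optional phoneme precompute in setup_loader, which is where length-based filtering pays off: with real input-sequence lengths available, min_seq_len/max_seq_len filter on what the model will actually see. A self-contained sketch of that kind of sort-and-filter pass; the field layout and names are illustrative, not the repository's:

def sort_and_filter(items, lengths, min_seq_len, max_seq_len):
    """Drop items outside [min_seq_len, max_seq_len] and sort the rest by
    input length so batches group similarly sized sequences."""
    kept = [(length, item) for length, item in zip(lengths, items)
            if min_seq_len <= length <= max_seq_len]
    kept.sort(key=lambda pair: pair[0])
    print(" | > {} instances discarded by length filtering".format(
        len(items) - len(kept)))
    return [item for _, item in kept]

# e.g., with lengths taken from the precomputed phoneme sequences:
# items = sort_and_filter(items, [len(seq) for seq in seqs], 6, 153)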