trainer-API updates

pull/506/head
Eren Gölge 2021-05-20 18:23:53 +02:00
parent 42554cc711
commit 8def3c87af
13 changed files with 62 additions and 31 deletions

View File

@ -229,7 +229,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
if global_step % config.tb_plot_step == 0:
iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
tb_logger.tb_train_step_stats(global_step, iter_stats)
if global_step % config.save_step == 0:
if config.checkpoint:

View File

@ -270,7 +270,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
if global_step % config.tb_plot_step == 0:
iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
tb_logger.tb_train_step_stats(global_step, iter_stats)
if global_step % config.save_step == 0:
if config.checkpoint:

View File

@ -256,7 +256,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
if global_step % config.tb_plot_step == 0:
iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
tb_logger.tb_train_step_stats(global_step, iter_stats)
if global_step % config.save_step == 0:
if config.checkpoint:

View File

@ -327,7 +327,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
"step_time": step_time,
}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
tb_logger.tb_train_step_stats(global_step, iter_stats)
if global_step % config.save_step == 0:
if config.checkpoint:

View File

@ -265,7 +265,7 @@ def train(
if global_step % 10 == 0:
iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
tb_logger.tb_train_step_stats(global_step, iter_stats)
# save checkpoint
if global_step % c.save_step == 0:

View File

@ -181,7 +181,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch
if global_step % 10 == 0:
iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
tb_logger.tb_train_step_stats(global_step, iter_stats)
# save checkpoint
if global_step % c.save_step == 0:

View File

@ -163,7 +163,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
if global_step % 10 == 0:
iter_stats = {"lr": cur_lr, "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
tb_logger.tb_train_step_stats(global_step, iter_stats)
# save checkpoint
if global_step % c.save_step == 0:

View File

@ -133,6 +133,18 @@ class BaseTTSConfig(BaseTrainingConfig):
datasets (List[BaseDatasetConfig]):
List of datasets used for training. If multiple datasets are provided, they are merged and used together
for training.
optimizer (str):
Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
Defaults to ``.
optimizer_params (dict):
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
lr_scheduler (str):
Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
`TTS.utils.training`. Defaults to ``.
lr_scheduler_params (dict):
Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
test_sentences (List[str]):
List of sentences to be used at testing. Defaults to '[]'
"""
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@ -158,3 +170,11 @@ class BaseTTSConfig(BaseTrainingConfig):
add_blank: bool = False
# dataset
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# optimizer
optimizer: str = MISSING
optimizer_params: dict = MISSING
# scheduler
lr_scheduler: str = ''
lr_scheduler_params: dict = field(default_factory=lambda: {})
# testing
test_sentences: List[str] = field(default_factory=lambda:[])

View File

@ -78,10 +78,16 @@ class TacotronConfig(BaseTTSConfig):
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
external_speaker_embedding_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
noam_schedule (bool):
enable / disable the use of Noam LR scheduler. Defaults to False.
warmup_steps (int):
Number of warm-up steps for the Noam scheduler. Defaults 4000.
optimizer (str):
Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
Defaults to `RAdam`.
optimizer_params (dict):
Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
lr_scheduler (str):
Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
`TTS.utils.training`. Defaults to `NoamLR`.
lr_scheduler_params (dict):
Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
lr (float):
Initial learning rate. Defaults to `1e-4`.
wd (float):
@ -152,10 +158,11 @@ class TacotronConfig(BaseTTSConfig):
external_speaker_embedding_file: str = False
# optimizer parameters
noam_schedule: bool = False
warmup_steps: int = 4000
optimizer: str = "RAdam"
optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda:{"warmup_steps": 4000})
lr: float = 1e-4
wd: float = 1e-6
grad_clip: float = 5.0
seq_len_norm: bool = False
loss_masking: bool = True

View File

@ -1,7 +1,7 @@
import json
import os
import random
from typing import Union
from typing import Union, List, Any
import numpy as np
import torch
@ -35,9 +35,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
def get_speakers(items):
"""Returns a sorted, unique list of speakers in a given dataset."""
speakers = {e[2] for e in items}
return sorted(speakers)
def parse_speakers(c, args, meta_data_train, OUT_PATH):
@ -121,26 +119,31 @@ class SpeakerManager:
Args:
x_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by the
TTS model. Defaults to "".
speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by
TTS models. Defaults to "".
encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "".
encoder_config_path (str, optional): Path to the spealer encoder config file. Defaults to "".
"""
def __init__(
self,
data_items: List[List[Any]] = None,
x_vectors_file_path: str = "",
speaker_id_file_path: str = "",
encoder_model_path: str = "",
encoder_config_path: str = "",
):
self.x_vectors = None
self.speaker_ids = None
self.clip_ids = None
self.data_items = []
self.x_vectors = []
self.speaker_ids = []
self.clip_ids = []
self.speaker_encoder = None
self.speaker_encoder_ap = None
if data_items:
self.speaker_ids = self.parse_speakers()
if x_vectors_file_path:
self.load_x_vectors_file(x_vectors_file_path)
@ -169,10 +172,10 @@ class SpeakerManager:
return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])
def parser_speakers_from_items(self, items: list):
speaker_ids = sorted({item[2] for item in items})
self.speaker_ids = speaker_ids
num_speakers = len(speaker_ids)
return speaker_ids, num_speakers
speakers = sorted({item[2] for item in items})
self.speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(self.speaker_ids)
return self.speaker_ids, num_speakers
def save_ids_file(self, file_path: str):
self._save_json(file_path, self.speaker_ids)

View File

@ -65,7 +65,7 @@ def basic_cleaners(text):
def transliteration_cleaners(text):
"""Pipeline for non-English text that transliterates to ASCII."""
text = convert_to_ascii(text)
# text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
@ -89,7 +89,7 @@ def basic_turkish_cleaners(text):
def english_cleaners(text):
"""Pipeline for English text, including number and abbreviation expansion."""
text = convert_to_ascii(text)
# text = convert_to_ascii(text)
text = lowercase(text)
text = expand_time_english(text)
text = expand_numbers(text)
@ -129,7 +129,7 @@ def chinese_mandarin_cleaners(text: str) -> str:
def phoneme_cleaners(text):
"""Pipeline for phonemes mode, including number and abbreviation expansion."""
text = expand_numbers(text)
text = convert_to_ascii(text)
# text = convert_to_ascii(text)
text = expand_abbreviations(text)
text = replace_symbols(text)
text = remove_aux_symbols(text)

View File

@ -39,7 +39,7 @@ class TensorboardLogger(object):
except RuntimeError:
traceback.print_exc()
def tb_train_iter_stats(self, step, stats):
def tb_train_step_stats(self, step, stats):
self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)
def tb_train_epoch_stats(self, step, stats):

View File

@ -21,6 +21,7 @@ config = MelganConfig(
print_step=1,
discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
print_eval=True,
discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
data_path="tests/data/ljspeech",
output_path=output_path,
)