mirror of https://github.com/coqui-ai/TTS.git
trainer-API updates
commit 8def3c87af (parent 42554cc711)
@@ -229,7 +229,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
        if global_step % config.tb_plot_step == 0:
            iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
            iter_stats.update(loss_dict)
-           tb_logger.tb_train_iter_stats(global_step, iter_stats)
+           tb_logger.tb_train_step_stats(global_step, iter_stats)

        if global_step % config.save_step == 0:
            if config.checkpoint:
@@ -270,7 +270,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
        if global_step % config.tb_plot_step == 0:
            iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
            iter_stats.update(loss_dict)
-           tb_logger.tb_train_iter_stats(global_step, iter_stats)
+           tb_logger.tb_train_step_stats(global_step, iter_stats)

        if global_step % config.save_step == 0:
            if config.checkpoint:
@@ -256,7 +256,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step,
        if global_step % config.tb_plot_step == 0:
            iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
            iter_stats.update(loss_dict)
-           tb_logger.tb_train_iter_stats(global_step, iter_stats)
+           tb_logger.tb_train_step_stats(global_step, iter_stats)

        if global_step % config.save_step == 0:
            if config.checkpoint:
@@ -327,7 +327,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap,
                "step_time": step_time,
            }
            iter_stats.update(loss_dict)
-           tb_logger.tb_train_iter_stats(global_step, iter_stats)
+           tb_logger.tb_train_step_stats(global_step, iter_stats)

        if global_step % config.save_step == 0:
            if config.checkpoint:
@@ -265,7 +265,7 @@ def train(
        if global_step % 10 == 0:
            iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time}
            iter_stats.update(loss_dict)
-           tb_logger.tb_train_iter_stats(global_step, iter_stats)
+           tb_logger.tb_train_step_stats(global_step, iter_stats)

        # save checkpoint
        if global_step % c.save_step == 0:
@@ -181,7 +181,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch
        if global_step % 10 == 0:
            iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time}
            iter_stats.update(loss_dict)
-           tb_logger.tb_train_iter_stats(global_step, iter_stats)
+           tb_logger.tb_train_step_stats(global_step, iter_stats)

        # save checkpoint
        if global_step % c.save_step == 0:
@@ -163,7 +163,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
        if global_step % 10 == 0:
            iter_stats = {"lr": cur_lr, "step_time": step_time}
            iter_stats.update(loss_dict)
-           tb_logger.tb_train_iter_stats(global_step, iter_stats)
+           tb_logger.tb_train_step_stats(global_step, iter_stats)

        # save checkpoint
        if global_step % c.save_step == 0:
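Every hunk above makes the same one-line change: the trainers now call tb_logger.tb_train_step_stats() instead of tb_logger.tb_train_iter_stats(), while the stats dict they pass is built exactly as before. A minimal sketch of what that dict ends up holding, with invented values standing in for the trainers' real variables:

# Illustrative values only; in the trainers these come from the optimizer,
# the step timer and the criterion's loss dict.
current_lr, grad_norm, step_time = 1e-4, 0.8, 0.31
loss_dict = {"loss": 2.73, "decoder_loss": 1.9, "stopnet_loss": 0.4}

iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
iter_stats.update(loss_dict)  # merge the per-step losses into the same flat dict
# iter_stats is the flat scalar dict handed to tb_logger.tb_train_step_stats(global_step, iter_stats)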
@@ -133,6 +133,18 @@ class BaseTTSConfig(BaseTrainingConfig):
        datasets (List[BaseDatasetConfig]):
            List of datasets used for training. If multiple datasets are provided, they are merged and used together
            for training.
+       optimizer (str):
+           Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
+           Defaults to ``.
+       optimizer_params (dict):
+           Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
+       lr_scheduler (str):
+           Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
+           `TTS.utils.training`. Defaults to ``.
+       lr_scheduler_params (dict):
+           Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
+       test_sentences (List[str]):
+           List of sentences to be used at testing. Defaults to '[]'
    """

    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@@ -158,3 +170,11 @@ class BaseTTSConfig(BaseTrainingConfig):
    add_blank: bool = False
    # dataset
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
+   # optimizer
+   optimizer: str = MISSING
+   optimizer_params: dict = MISSING
+   # scheduler
+   lr_scheduler: str = ''
+   lr_scheduler_params: dict = field(default_factory=lambda: {})
+   # testing
+   test_sentences: List[str] = field(default_factory=lambda: [])
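A plausible reading of the new string-valued fields is that the trainer resolves the optimizer and scheduler by name at run time from config.optimizer / config.lr_scheduler plus their *_params dicts. A minimal sketch of that lookup for a torch.optim name; the helper name and signature are illustrative, not the repo's code:

import torch

def build_optimizer(model, config, lr):
    # Resolve the optimizer class by name, e.g. config.optimizer = "Adam".
    # The docstring also allows names from TTS.utils.training, which this sketch ignores.
    optimizer_class = getattr(torch.optim, config.optimizer)
    return optimizer_class(model.parameters(), lr=lr, **config.optimizer_params)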
@@ -78,10 +78,16 @@ class TacotronConfig(BaseTTSConfig):
            enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
        external_speaker_embedding_file (str):
            Path to the file including pre-computed speaker embeddings. Defaults to None.
-       noam_schedule (bool):
-           enable / disable the use of Noam LR scheduler. Defaults to False.
-       warmup_steps (int):
-           Number of warm-up steps for the Noam scheduler. Defaults 4000.
+       optimizer (str):
+           Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
+           Defaults to `RAdam`.
+       optimizer_params (dict):
+           Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
+       lr_scheduler (str):
+           Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
+           `TTS.utils.training`. Defaults to `NoamLR`.
+       lr_scheduler_params (dict):
+           Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
        lr (float):
            Initial learning rate. Defaults to `1e-4`.
        wd (float):
@@ -152,10 +158,11 @@ class TacotronConfig(BaseTTSConfig):
    external_speaker_embedding_file: str = False

    # optimizer parameters
-   noam_schedule: bool = False
-   warmup_steps: int = 4000
+   optimizer: str = "RAdam"
+   optimizer_params: dict = field(default_factory=lambda: {'betas': [0.9, 0.998], 'weight_decay': 1e-6})
+   lr_scheduler: str = "NoamLR"
+   lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
    lr: float = 1e-4
    wd: float = 1e-6
    grad_clip: float = 5.0
    seq_len_norm: bool = False
    loss_masking: bool = True
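For reference, the new defaults above can also be overridden when the config is built. A short sketch, assuming TacotronConfig is still importable from TTS.tts.configs.tacotron_config on this branch; the values simply restate the defaults from the hunk:

from TTS.tts.configs.tacotron_config import TacotronConfig  # import path assumed

config = TacotronConfig(
    optimizer="RAdam",
    optimizer_params={"betas": [0.9, 0.998], "weight_decay": 1e-6},
    lr_scheduler="NoamLR",
    lr_scheduler_params={"warmup_steps": 4000},
    lr=1e-4,
)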
@@ -1,7 +1,7 @@
import json
import os
import random
-from typing import Union
+from typing import Union, List, Any

import numpy as np
import torch
@@ -35,9 +35,7 @@ def save_speaker_mapping(out_path, speaker_mapping):


def get_speakers(items):
    """Returns a sorted, unique list of speakers in a given dataset."""
    speakers = {e[2] for e in items}
    return sorted(speakers)


def parse_speakers(c, args, meta_data_train, OUT_PATH):
@@ -121,26 +119,31 @@ class SpeakerManager:

    Args:
        x_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
-       speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by the
-           TTS model. Defaults to "".
+       speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by
+           TTS models. Defaults to "".
        encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "".
        encoder_config_path (str, optional): Path to the spealer encoder config file. Defaults to "".
    """

    def __init__(
        self,
+       data_items: List[List[Any]] = None,
        x_vectors_file_path: str = "",
        speaker_id_file_path: str = "",
        encoder_model_path: str = "",
        encoder_config_path: str = "",
    ):

-       self.x_vectors = None
-       self.speaker_ids = None
-       self.clip_ids = None
+       self.data_items = []
+       self.x_vectors = []
+       self.speaker_ids = []
+       self.clip_ids = []
        self.speaker_encoder = None
        self.speaker_encoder_ap = None

+       if data_items:
+           self.speaker_ids = self.parse_speakers()
+
        if x_vectors_file_path:
            self.load_x_vectors_file(x_vectors_file_path)
@@ -169,10 +172,10 @@ class SpeakerManager:
        return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])

    def parser_speakers_from_items(self, items: list):
-       speaker_ids = sorted({item[2] for item in items})
-       self.speaker_ids = speaker_ids
-       num_speakers = len(speaker_ids)
-       return speaker_ids, num_speakers
+       speakers = sorted({item[2] for item in items})
+       self.speaker_ids = {name: i for i, name in enumerate(speakers)}
+       num_speakers = len(self.speaker_ids)
+       return self.speaker_ids, num_speakers

    def save_ids_file(self, file_path: str):
        self._save_json(file_path, self.speaker_ids)
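The reworked method now returns a name-to-index mapping instead of a plain list of names, so self.speaker_ids can be used directly to look up embedding indices. A small standalone sketch of the new behaviour; the sample items are invented but follow the [text, wav_path, speaker_name] layout the method indexes with item[2]:

items = [
    ["hello there", "wavs/a.wav", "speaker_b"],
    ["good morning", "wavs/b.wav", "speaker_a"],
    ["good night", "wavs/c.wav", "speaker_a"],
]
speakers = sorted({item[2] for item in items})               # ["speaker_a", "speaker_b"]
speaker_ids = {name: i for i, name in enumerate(speakers)}   # {"speaker_a": 0, "speaker_b": 1}
num_speakers = len(speaker_ids)                              # 2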
@@ -65,7 +65,7 @@ def basic_cleaners(text):

def transliteration_cleaners(text):
    """Pipeline for non-English text that transliterates to ASCII."""
-   text = convert_to_ascii(text)
+   # text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text
@@ -89,7 +89,7 @@ def basic_turkish_cleaners(text):

def english_cleaners(text):
    """Pipeline for English text, including number and abbreviation expansion."""
-   text = convert_to_ascii(text)
+   # text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_time_english(text)
    text = expand_numbers(text)
@@ -129,7 +129,7 @@ def chinese_mandarin_cleaners(text: str) -> str:
def phoneme_cleaners(text):
    """Pipeline for phonemes mode, including number and abbreviation expansion."""
    text = expand_numbers(text)
-   text = convert_to_ascii(text)
+   # text = convert_to_ascii(text)
    text = expand_abbreviations(text)
    text = replace_symbols(text)
    text = remove_aux_symbols(text)
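All three cleaner pipelines now skip the ASCII transliteration step, so accented characters survive cleaning. A rough standalone illustration; lowercase() and collapse_whitespace() below are simplified stand-ins for the module's helpers:

import re

def lowercase(text):
    return text.lower()

def collapse_whitespace(text):
    return re.sub(r"\s+", " ", text)

# With convert_to_ascii() commented out, "Café  au   lait" keeps its accent:
print(collapse_whitespace(lowercase("Café  au   lait")))  # -> "café au lait"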
@@ -39,7 +39,7 @@ class TensorboardLogger(object):
        except RuntimeError:
            traceback.print_exc()

-   def tb_train_iter_stats(self, step, stats):
+   def tb_train_step_stats(self, step, stats):
        self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)

    def tb_train_epoch_stats(self, step, stats):
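The rename only touches the method name; the scalars still land under the "<model_name>_TrainIterStats" scope via dict_to_tb_scalar. A minimal sketch of that fan-out pattern using torch's SummaryWriter; this is a toy stand-in, not the repo's TensorboardLogger:

from torch.utils.tensorboard import SummaryWriter

class MiniLogger:
    """Toy logger illustrating how a stats dict becomes per-key scalar curves."""

    def __init__(self, log_dir, model_name):
        self.writer = SummaryWriter(log_dir)
        self.model_name = model_name

    def dict_to_tb_scalar(self, scope_name, stats, step):
        # one scalar curve per key, e.g. "MyModel_TrainIterStats/loss"
        for key, value in stats.items():
            self.writer.add_scalar(f"{scope_name}/{key}", value, step)

    def tb_train_step_stats(self, step, stats):
        self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)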
@@ -21,6 +21,7 @@ config = MelganConfig(
    print_step=1,
-   discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
+   print_eval=True,
+   discriminator_model_params={"base_channels": 16, "max_channels": 256, "downsample_factors": [4, 4, 4]},
    data_path="tests/data/ljspeech",
    output_path=output_path,
)