mirror of https://github.com/coqui-ai/TTS.git
Update `align_tts` for trainer_v2
parent
8ada870a57
commit
d9df33f837
|
@ -103,7 +103,7 @@ class AlignTTS(BaseTTS):
|
|||
|
||||
def __init__(self, config: Coqpit):
|
||||
|
||||
super().__init__()
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.phase = -1
|
||||
self.length_scale = (
|
||||
|
@ -360,9 +360,7 @@ class AlignTTS(BaseTTS):
|
|||
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(
|
||||
self, ap: AudioProcessor, batch: dict, outputs: dict
|
||||
) -> Tuple[Dict, Dict]: # pylint: disable=no-self-use
|
||||
def _create_logs(self, batch, outputs, ap):
|
||||
model_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
@ -381,11 +379,22 @@ class AlignTTS(BaseTTS):
|
|||
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||
return figures, {"audio": train_audio}
|
||||
|
||||
def train_log(
|
||||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||
) -> None: # pylint: disable=no-self-use
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
logger.train_figures(steps, figures)
|
||||
logger.train_audios(steps, audios, ap.sample_rate)
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||
return self.train_log(ap, batch, outputs)
|
||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
import os
|
||||
|
||||
from TTS.trainer import Trainer, TrainingArgs, init_training
|
||||
from TTS.trainer import Trainer, TrainingArgs
|
||||
from TTS.tts.configs import AlignTTSConfig, BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.models.align_tts import AlignTTS
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
output_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# init configs
|
||||
dataset_config = BaseDatasetConfig(
|
||||
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
|
||||
)
|
||||
|
@ -25,6 +30,24 @@ config = AlignTTSConfig(
|
|||
output_path=output_path,
|
||||
datasets=[dataset_config],
|
||||
)
|
||||
args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
|
||||
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
|
||||
|
||||
# init audio processor
|
||||
ap = AudioProcessor(**config.audio.to_dict())
|
||||
|
||||
# load training samples
|
||||
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
||||
|
||||
# init model
|
||||
model = AlignTTS(config)
|
||||
|
||||
# init the trainer and 🚀
|
||||
trainer = Trainer(
|
||||
TrainingArgs(),
|
||||
config,
|
||||
output_path,
|
||||
model=model,
|
||||
train_samples=train_samples,
|
||||
eval_samples=eval_samples,
|
||||
training_assets={"audio_processor": ap},
|
||||
)
|
||||
trainer.fit()
|
||||
|
|
Loading…
Reference in New Issue