Update AlignTTS

pull/1324/head
Eren Gölge 2021-12-07 12:56:24 +00:00
parent 18f726af65
commit bacf79f4fb
1 changed files with 30 additions and 13 deletions

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, List, Union
import torch
from coqpit import Coqpit
@ -12,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec
@ -100,11 +102,16 @@ class AlignTTS(BaseTTS):
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
def __init__(
self,
config: "AlignTTSConfig",
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config)
super().__init__(config, ap, tokenizer, speaker_manager)
self.speaker_manager = speaker_manager
self.config = config
self.phase = -1
self.length_scale = (
float(config.model_args.length_scale)
@ -112,10 +119,6 @@ class AlignTTS(BaseTTS):
else config.model_args.length_scale
)
if not self.config.model_args.num_chars:
_, self.config, num_chars = self.get_characters(config)
self.config.model_args.num_chars = num_chars
self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels)
self.embedded_speaker_dim = 0
@ -382,19 +385,17 @@ class AlignTTS(BaseTTS):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None: # pylint: disable=no-self-use
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
def load_checkpoint(
self, config, checkpoint_path, eval=False
@ -430,3 +431,19 @@ class AlignTTS(BaseTTS):
def on_epoch_start(self, trainer):
"""Set AlignTTS training phase on epoch start."""
self.phase = self._set_phase(trainer.config, trainer.total_steps_done)
@staticmethod
def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (AlignTTSConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return AlignTTS(new_config, ap, tokenizer, speaker_manager)