mirror of https://github.com/coqui-ai/TTS.git
Update AlignTTS
parent
18f726af65
commit
bacf79f4fb
|
@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
@ -12,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
|||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
@ -100,11 +102,16 @@ class AlignTTS(BaseTTS):
|
|||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: "AlignTTSConfig",
|
||||
ap: "AudioProcessor" = None,
|
||||
tokenizer: "TTSTokenizer" = None,
|
||||
speaker_manager: SpeakerManager = None,
|
||||
):
|
||||
|
||||
super().__init__(config)
|
||||
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||
self.speaker_manager = speaker_manager
|
||||
self.config = config
|
||||
self.phase = -1
|
||||
self.length_scale = (
|
||||
float(config.model_args.length_scale)
|
||||
|
@ -112,10 +119,6 @@ class AlignTTS(BaseTTS):
|
|||
else config.model_args.length_scale
|
||||
)
|
||||
|
||||
if not self.config.model_args.num_chars:
|
||||
_, self.config, num_chars = self.get_characters(config)
|
||||
self.config.model_args.num_chars = num_chars
|
||||
|
||||
self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels)
|
||||
|
||||
self.embedded_speaker_dim = 0
|
||||
|
@ -382,19 +385,17 @@ class AlignTTS(BaseTTS):
|
|||
def train_log(
    self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None:  # pylint: disable=no-self-use
    """Log training figures and audio samples for the current step.

    Args:
        batch (dict): Batch passed through to ``self._create_logs``.
        outputs (dict): Model outputs passed through to ``self._create_logs``.
        logger (Logger): Logger exposing ``train_figures`` and ``train_audios``.
        assets (dict): Unused; kept so existing callers that still pass it
            keep working. The audio processor is read from ``self.ap``.
        steps (int): Step index forwarded to the logger.
    """
    # The audio processor now lives on the model (``self.ap``), so nothing is
    # looked up in ``assets`` and each artifact is logged exactly once.
    figures, audios = self._create_logs(batch, outputs, self.ap)
    logger.train_figures(steps, figures)
    logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
    """Run one evaluation step by delegating to ``self.train_step``.

    Args:
        batch (dict): Batch forwarded unchanged to ``train_step``.
        criterion (nn.Module): Loss module forwarded unchanged to ``train_step``.

    Returns:
        Whatever ``self.train_step(batch, criterion)`` returns.
    """
    return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
    """Log evaluation figures and audio samples for the current step.

    Args:
        batch (dict): Batch passed through to ``self._create_logs``.
        outputs (dict): Model outputs passed through to ``self._create_logs``.
        logger (Logger): Logger exposing ``eval_figures`` and ``eval_audios``.
        assets (dict): Unused; kept so existing callers that still pass it
            keep working. The audio processor is read from ``self.ap``.
        steps (int): Step index forwarded to the logger.
    """
    # The audio processor now lives on the model (``self.ap``), so nothing is
    # looked up in ``assets`` and each artifact is logged exactly once.
    figures, audios = self._create_logs(batch, outputs, self.ap)
    logger.eval_figures(steps, figures)
    logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
|
@ -430,3 +431,19 @@ class AlignTTS(BaseTTS):
|
|||
def on_epoch_start(self, trainer):
    """Set AlignTTS training phase on epoch start.

    Stores in ``self.phase`` the value returned by ``self._set_phase`` for
    the trainer's config and its count of completed steps.

    Args:
        trainer: Object exposing ``config`` and ``total_steps_done``.
    """
    self.phase = self._set_phase(trainer.config, trainer.total_steps_done)
|
||||
|
||||
@staticmethod
def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None):
    """Initialize the model from a config.

    Args:
        config (AlignTTSConfig): Model config.
        samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
            Defaults to None.

    Returns:
        AlignTTS: Model built from the config returned by the tokenizer
        initialisation, together with the audio processor, tokenizer and
        speaker manager created here.
    """
    # Imported at call time rather than at module level
    # (NOTE(review): presumably to avoid a circular import — confirm).
    from TTS.utils.audio import AudioProcessor

    ap = AudioProcessor.init_from_config(config)
    # The tokenizer initialiser also returns a config; that returned config,
    # not the one passed in, is what the model is constructed with.
    tokenizer, new_config = TTSTokenizer.init_from_config(config)
    speaker_manager = SpeakerManager.init_from_config(config, samples)
    return AlignTTS(new_config, ap, tokenizer, speaker_manager)
|
||||
|
|
Loading…
Reference in New Issue