mirror of https://github.com/coqui-ai/TTS.git
Update VITS for the new API
parent f802a931a3
commit ea965a5683
@@ -1,7 +1,8 @@
 import math
-from dataclasses import dataclass, field
+import random
+from dataclasses import dataclass, field, replace
 from itertools import chain
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Union
 
 import torch
 import torchaudio
|
@@ -10,6 +11,7 @@ from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 
+from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
@@ -19,6 +21,7 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
+from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
|
@@ -283,91 +286,79 @@ class Vits(BaseTTS):
         self.END2END = True
         self.speaker_manager = speaker_manager
         self.language_manager = language_manager
-        if config.__class__.__name__ == "VitsConfig":
-            # loading from VitsConfig
-            self.num_chars = self.tokenizer.characters.num_chars
-            self.config = config
-            args = self.config.model_args
-        elif isinstance(config, VitsArgs):
-            # loading from VitsArgs
-            self.config = config
-            args = config
-        else:
-            raise ValueError("config must be either a VitsConfig or VitsArgs")
 
-        self.args = args
 
         self.init_multispeaker(config)
         self.init_multilingual(config)
 
-        self.length_scale = args.length_scale
-        self.noise_scale = args.noise_scale
-        self.inference_noise_scale = args.inference_noise_scale
-        self.inference_noise_scale_dp = args.inference_noise_scale_dp
-        self.noise_scale_dp = args.noise_scale_dp
-        self.max_inference_len = args.max_inference_len
-        self.spec_segment_size = args.spec_segment_size
+        self.length_scale = self.args.length_scale
+        self.noise_scale = self.args.noise_scale
+        self.inference_noise_scale = self.args.inference_noise_scale
+        self.inference_noise_scale_dp = self.args.inference_noise_scale_dp
+        self.noise_scale_dp = self.args.noise_scale_dp
+        self.max_inference_len = self.args.max_inference_len
+        self.spec_segment_size = self.args.spec_segment_size
 
         self.text_encoder = TextEncoder(
-            args.num_chars,
-            args.hidden_channels,
-            args.hidden_channels,
-            args.hidden_channels_ffn_text_encoder,
-            args.num_heads_text_encoder,
-            args.num_layers_text_encoder,
-            args.kernel_size_text_encoder,
-            args.dropout_p_text_encoder,
-            language_emb_dim=self.embedded_language_dim,
+            self.args.num_chars,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            self.args.hidden_channels_ffn_text_encoder,
+            self.args.num_heads_text_encoder,
+            self.args.num_layers_text_encoder,
+            self.args.kernel_size_text_encoder,
+            self.args.dropout_p_text_encoder,
         )
 
         self.posterior_encoder = PosteriorEncoder(
-            args.out_channels,
-            args.hidden_channels,
-            args.hidden_channels,
-            kernel_size=args.kernel_size_posterior_encoder,
-            dilation_rate=args.dilation_rate_posterior_encoder,
-            num_layers=args.num_layers_posterior_encoder,
+            self.args.out_channels,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            kernel_size=self.args.kernel_size_posterior_encoder,
+            dilation_rate=self.args.dilation_rate_posterior_encoder,
+            num_layers=self.args.num_layers_posterior_encoder,
             cond_channels=self.embedded_speaker_dim,
         )
 
         self.flow = ResidualCouplingBlocks(
-            args.hidden_channels,
-            args.hidden_channels,
-            kernel_size=args.kernel_size_flow,
-            dilation_rate=args.dilation_rate_flow,
-            num_layers=args.num_layers_flow,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            kernel_size=self.args.kernel_size_flow,
+            dilation_rate=self.args.dilation_rate_flow,
+            num_layers=self.args.num_layers_flow,
             cond_channels=self.embedded_speaker_dim,
         )
 
-        if args.use_sdp:
+        if self.args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
-                args.hidden_channels,
+                self.args.hidden_channels,
                 192,
                 3,
-                args.dropout_p_duration_predictor,
+                self.args.dropout_p_duration_predictor,
                 4,
                 cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
                 language_emb_dim=self.embedded_language_dim,
             )
         else:
             self.duration_predictor = DurationPredictor(
-                args.hidden_channels,
+                self.args.hidden_channels,
                 256,
                 3,
-                args.dropout_p_duration_predictor,
-                cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
+                self.args.dropout_p_duration_predictor,
+                cond_channels=self.embedded_speaker_dim,
                 language_emb_dim=self.embedded_language_dim,
             )
 
         self.waveform_decoder = HifiganGenerator(
-            args.hidden_channels,
+            self.args.hidden_channels,
             1,
-            args.resblock_type_decoder,
-            args.resblock_dilation_sizes_decoder,
-            args.resblock_kernel_sizes_decoder,
-            args.upsample_kernel_sizes_decoder,
-            args.upsample_initial_channel_decoder,
-            args.upsample_rates_decoder,
+            self.args.resblock_type_decoder,
+            self.args.resblock_dilation_sizes_decoder,
+            self.args.resblock_kernel_sizes_decoder,
+            self.args.upsample_kernel_sizes_decoder,
+            self.args.upsample_initial_channel_decoder,
+            self.args.upsample_rates_decoder,
             inference_padding=0,
             cond_channels=self.embedded_speaker_dim,
             conv_pre_weight_norm=False,
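Note: with the deleted if/elif dispatch, the constructor no longer accepts a bare VitsArgs, and every hyperparameter is read from self.args instead of a local args. A minimal sketch of the assumption behind that: this diff implies, but does not show, that the base-class constructor now resolves config.model_args once for all models (BaseTTSSketch below is a hypothetical stand-in, not the repo's BaseTTS):

    # Sketch only; not the repo's actual BaseTTS code.
    from coqpit import Coqpit
    from torch import nn

    class BaseTTSSketch(nn.Module):
        def __init__(self, config: Coqpit, ap, tokenizer, speaker_manager=None):
            super().__init__()
            self.config = config
            self.ap = ap  # the model now owns its AudioProcessor (see the _log hunks below)
            self.tokenizer = tokenizer
            self.speaker_manager = speaker_manager
            self.args = config.model_args  # why `self.args.*` replaces the local `args`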
|
@@ -375,8 +366,8 @@ class Vits(BaseTTS):
             conv_post_bias=False,
         )
 
-        if args.init_discriminator:
-            self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
+        if self.args.init_discriminator:
+            self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator)
 
     def init_multispeaker(self, config: Coqpit):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
@@ -883,19 +874,17 @@ class Vits(BaseTTS):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        ap = assets["audio_processor"]
-        self._log(ap, batch, outputs, "train")
+        self._log(self.ap, batch, outputs, "train")
 
     @torch.no_grad()
     def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
         return self.train_step(batch, criterion, optimizer_idx)
 
     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        ap = assets["audio_processor"]
-        return self._log(ap, batch, outputs, "eval")
+        return self._log(self.ap, batch, outputs, "eval")
 
     @torch.no_grad()
-    def test_run(self, ap) -> Tuple[Dict, Dict]:
+    def test_run(self) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.
 
         You can override this for a different behaviour.
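Note: the logging hooks stop fishing the audio processor out of the Trainer's assets dict because the model now carries it as self.ap, and test_run drops its ap parameter for the same reason. A hedged usage sketch, assuming train_log's full signature mirrors the eval_log shown above:

    # Sketch: no `assets` plumbing is needed once the model owns self.ap.
    model = Vits.init_from_config(config)  # builds ap, tokenizer, speaker_manager (see below)
    model.eval_log(batch, outputs, logger, assets={}, steps=0)  # assets dict is no longer read
    outputs = model.test_run()  # was: model.test_run(ap)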
|
@@ -990,36 +979,6 @@ class Vits(BaseTTS):
 
         return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]
 
-    @staticmethod
-    def make_symbols(config):
-        """Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the
-        whole training and inference steps."""
-        _pad = config.characters["pad"]
-        _punctuations = config.characters["punctuations"]
-        _letters = config.characters["characters"]
-        _letters_ipa = config.characters["phonemes"]
-        symbols = [_pad] + list(_punctuations) + list(_letters)
-        if config.use_phonemes:
-            symbols += list(_letters_ipa)
-        return symbols
-
-    @staticmethod
-    def get_characters(config: Coqpit):
-        if config.characters is not None:
-            symbols = Vits.make_symbols(config)
-        else:
-            from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
-                parse_symbols,
-                phonemes,
-                symbols,
-            )
-
-            config.characters = parse_symbols()
-        if config.use_phonemes:
-            symbols = phonemes
-        num_chars = len(symbols) + getattr(config, "add_blank", False)
-        return symbols, config, num_chars
-
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
     ):  # pylint: disable=unused-argument, redefined-builtin
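Note: make_symbols and get_characters are removed rather than moved verbatim; their job (assembling pad + punctuations + letters, plus IPA phonemes when config.use_phonemes is set) passes to the tokenizer and to the VitsCharacters class added in the next hunk. A rough equivalence, sketched using only names this diff introduces:

    # old: symbols = Vits.make_symbols(config)  -> a plain list of symbols
    # new: the characters object owns the table and both id mappings
    characters, new_config = VitsCharacters.init_from_config(config)
    symbols = characters.vocab  # [pad] + punctuations + characters + ["<BLNK>"]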
|
@@ -1035,23 +994,65 @@ class Vits(BaseTTS):
         assert not self.training
 
     @staticmethod
-    def init_from_config(config: "Coqpit"):
-        """Initialize model from config."""
-
-        # init characters
-        if config.use_phonemes:
-            from TTS.tts.utils.text.characters import IPAPhonemes
-
-            characters = IPAPhonemes().init_from_config(config)
-        else:
-            from TTS.tts.utils.text.characters import Graphemes
-
-            characters = Graphemes().init_from_config(config)
-        config.num_chars = characters.num_chars
+    def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None):
+        """Initiate model from config
+
+        Args:
+            config (VitsConfig): Model config.
+            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
+                Defaults to None.
+        """
         from TTS.utils.audio import AudioProcessor
 
         ap = AudioProcessor.init_from_config(config)
-        tokenizer = TTSTokenizer.init_from_config(config)
-        speaker_manager = SpeakerManager.init_from_config(config)
-        return Vits(config, ap, tokenizer, speaker_manager)
+        tokenizer, new_config = TTSTokenizer.init_from_config(config)
+        speaker_manager = SpeakerManager.init_from_config(config, samples)
+        return Vits(new_config, ap, tokenizer, speaker_manager)
 
+
+class VitsCharacters(BaseCharacters):
+    """Characters class for VITs model for compatibility with pre-trained models"""
+
+    def __init__(
+        self,
+        graphemes: str = _characters,
+        punctuations: str = _punctuations,
+        pad: str = _pad,
+        ipa_characters: str = _phonemes,
+    ) -> None:
+        if ipa_characters is not None:
+            graphemes += ipa_characters
+        super().__init__(graphemes, punctuations, pad, None, None, "<BLNK>", is_unique=False, is_sorted=True)
+
+    def _create_vocab(self):
+        self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
+        self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
+        # pylint: disable=unnecessary-comprehension
+        self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
+
+    @staticmethod
+    def init_from_config(config: Coqpit):
+        if config.characters is not None:
+            _pad = config.characters["pad"]
+            _punctuations = config.characters["punctuations"]
+            _letters = config.characters["characters"]
+            _letters_ipa = config.characters["phonemes"]
+            return (
+                VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad),
+                config,
+            )
+        characters = VitsCharacters()
+        new_config = replace(config, characters=characters.to_config())
+        return characters, new_config
+
+    def to_config(self) -> "CharactersConfig":
+        return CharactersConfig(
+            characters=self._characters,
+            punctuations=self._punctuations,
+            pad=self._pad,
+            eos=None,
+            bos=None,
+            blank=self._blank,
+            is_unique=False,
+            is_sorted=True,
+        )
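Note: Vits.init_from_config is now the single entry point that wires up the model: it builds the AudioProcessor, lets TTSTokenizer rewrite the config (e.g. filling in characters via VitsCharacters), parses speaker ids from samples, and constructs Vits from the rewritten new_config. A usage sketch; the VitsConfig import path is assumed from the repo layout rather than shown in this diff:

    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits

    config = VitsConfig()
    # `samples` may carry the training items so SpeakerManager can parse speaker ids
    model = Vits.init_from_config(config, samples=None)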
|