Update VITS for the new API

pull/1324/head
Eren Gölge 2021-12-07 12:55:18 +00:00
parent f802a931a3
commit ea965a5683
1 changed file with 107 additions and 106 deletions


@@ -1,7 +1,8 @@
import math
from dataclasses import dataclass, field
import random
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union
import torch
import torchaudio
@@ -10,6 +11,7 @@ from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as F
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
@@ -19,6 +21,7 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
@@ -283,91 +286,79 @@ class Vits(BaseTTS):
self.END2END = True
self.speaker_manager = speaker_manager
self.language_manager = language_manager
if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig
self.num_chars = self.tokenizer.characters.num_chars
self.config = config
args = self.config.model_args
elif isinstance(config, VitsArgs):
# loading from VitsArgs
self.config = config
args = config
else:
raise ValueError("config must be either a VitsConfig or VitsArgs")
self.args = args
self.init_multispeaker(config)
self.init_multilingual(config)
self.length_scale = args.length_scale
self.noise_scale = args.noise_scale
self.inference_noise_scale = args.inference_noise_scale
self.inference_noise_scale_dp = args.inference_noise_scale_dp
self.noise_scale_dp = args.noise_scale_dp
self.max_inference_len = args.max_inference_len
self.spec_segment_size = args.spec_segment_size
self.length_scale = self.args.length_scale
self.noise_scale = self.args.noise_scale
self.inference_noise_scale = self.args.inference_noise_scale
self.inference_noise_scale_dp = self.args.inference_noise_scale_dp
self.noise_scale_dp = self.args.noise_scale_dp
self.max_inference_len = self.args.max_inference_len
self.spec_segment_size = self.args.spec_segment_size
self.text_encoder = TextEncoder(
args.num_chars,
args.hidden_channels,
args.hidden_channels,
args.hidden_channels_ffn_text_encoder,
args.num_heads_text_encoder,
args.num_layers_text_encoder,
args.kernel_size_text_encoder,
args.dropout_p_text_encoder,
language_emb_dim=self.embedded_language_dim,
self.args.num_chars,
self.args.hidden_channels,
self.args.hidden_channels,
self.args.hidden_channels_ffn_text_encoder,
self.args.num_heads_text_encoder,
self.args.num_layers_text_encoder,
self.args.kernel_size_text_encoder,
self.args.dropout_p_text_encoder,
)
self.posterior_encoder = PosteriorEncoder(
args.out_channels,
args.hidden_channels,
args.hidden_channels,
kernel_size=args.kernel_size_posterior_encoder,
dilation_rate=args.dilation_rate_posterior_encoder,
num_layers=args.num_layers_posterior_encoder,
self.args.out_channels,
self.args.hidden_channels,
self.args.hidden_channels,
kernel_size=self.args.kernel_size_posterior_encoder,
dilation_rate=self.args.dilation_rate_posterior_encoder,
num_layers=self.args.num_layers_posterior_encoder,
cond_channels=self.embedded_speaker_dim,
)
self.flow = ResidualCouplingBlocks(
args.hidden_channels,
args.hidden_channels,
kernel_size=args.kernel_size_flow,
dilation_rate=args.dilation_rate_flow,
num_layers=args.num_layers_flow,
self.args.hidden_channels,
self.args.hidden_channels,
kernel_size=self.args.kernel_size_flow,
dilation_rate=self.args.dilation_rate_flow,
num_layers=self.args.num_layers_flow,
cond_channels=self.embedded_speaker_dim,
)
if args.use_sdp:
if self.args.use_sdp:
self.duration_predictor = StochasticDurationPredictor(
args.hidden_channels,
self.args.hidden_channels,
192,
3,
args.dropout_p_duration_predictor,
self.args.dropout_p_duration_predictor,
4,
cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
language_emb_dim=self.embedded_language_dim,
)
else:
self.duration_predictor = DurationPredictor(
args.hidden_channels,
self.args.hidden_channels,
256,
3,
args.dropout_p_duration_predictor,
cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
self.args.dropout_p_duration_predictor,
cond_channels=self.embedded_speaker_dim,
language_emb_dim=self.embedded_language_dim,
)
self.waveform_decoder = HifiganGenerator(
args.hidden_channels,
self.args.hidden_channels,
1,
args.resblock_type_decoder,
args.resblock_dilation_sizes_decoder,
args.resblock_kernel_sizes_decoder,
args.upsample_kernel_sizes_decoder,
args.upsample_initial_channel_decoder,
args.upsample_rates_decoder,
self.args.resblock_type_decoder,
self.args.resblock_dilation_sizes_decoder,
self.args.resblock_kernel_sizes_decoder,
self.args.upsample_kernel_sizes_decoder,
self.args.upsample_initial_channel_decoder,
self.args.upsample_rates_decoder,
inference_padding=0,
cond_channels=self.embedded_speaker_dim,
conv_pre_weight_norm=False,
@@ -375,8 +366,8 @@ class Vits(BaseTTS):
conv_post_bias=False,
)
if args.init_discriminator:
self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
if self.args.init_discriminator:
self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator)
def init_multispeaker(self, config: Coqpit):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
@@ -883,19 +874,17 @@ class Vits(BaseTTS):
Returns:
Tuple[Dict, np.ndarray]: training plots and output waveform.
"""
ap = assets["audio_processor"]
self._log(ap, batch, outputs, "train")
self._log(self.ap, batch, outputs, "train")
@torch.no_grad()
def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
return self._log(ap, batch, outputs, "eval")
return self._log(self.ap, batch, outputs, "eval")
@torch.no_grad()
def test_run(self, ap) -> Tuple[Dict, Dict]:
def test_run(self) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
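
In the same spirit, train_log and eval_log stop pulling the audio processor out of assets and use the model-owned self.ap, and test_run no longer takes an ap argument. A hedged sketch of the calling side; batch, outputs and logger stand in for whatever the Trainer provides, and model is the instance from the sketch above:

# Sketch only; placeholder batch/outputs/logger, model from the previous sketch.
model.eval_log(batch, outputs, logger, assets={}, steps=1000)  # uses model.ap internally
figures, audios = model.test_run()  # no ap argument under the new API
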
@@ -990,36 +979,6 @@ class Vits(BaseTTS):
return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]
@staticmethod
def make_symbols(config):
"""Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the
whole training and inference steps."""
_pad = config.characters["pad"]
_punctuations = config.characters["punctuations"]
_letters = config.characters["characters"]
_letters_ipa = config.characters["phonemes"]
symbols = [_pad] + list(_punctuations) + list(_letters)
if config.use_phonemes:
symbols += list(_letters_ipa)
return symbols
@staticmethod
def get_characters(config: Coqpit):
if config.characters is not None:
symbols = Vits.make_symbols(config)
else:
from TTS.tts.utils.text.symbols import ( # pylint: disable=import-outside-toplevel
parse_symbols,
phonemes,
symbols,
)
config.characters = parse_symbols()
if config.use_phonemes:
symbols = phonemes
num_chars = len(symbols) + getattr(config, "add_blank", False)
return symbols, config, num_chars
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
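
The two static methods removed above are superseded by the characters/tokenizer layer: the constructor now reads self.tokenizer.characters.num_chars, and the symbol table comes from the VitsCharacters class added at the end of this diff. A rough equivalent under the new API (the VitsConfig import path is an assumption):

# Sketch of the replacement for make_symbols()/get_characters(); not part of the commit.
from TTS.tts.configs.vits_config import VitsConfig  # assumed path
from TTS.tts.utils.text.tokenizer import TTSTokenizer

config = VitsConfig()
tokenizer, config = TTSTokenizer.init_from_config(config)
num_chars = tokenizer.characters.num_chars   # replaces Vits.get_characters()
symbols = tokenizer.characters.vocab         # replaces Vits.make_symbols()
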
@@ -1035,23 +994,65 @@ class Vits(BaseTTS):
assert not self.training
@staticmethod
def init_from_config(config: "Coqpit"):
"""Initialize model from config."""
# init characters
if config.use_phonemes:
from TTS.tts.utils.text.characters import IPAPhonemes
characters = IPAPhonemes().init_from_config(config)
else:
from TTS.tts.utils.text.characters import Graphemes
characters = Graphemes().init_from_config(config)
config.num_chars = characters.num_chars
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (VitsConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config)
return Vits(config, ap, tokenizer, speaker_manager)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return Vits(new_config, ap, tokenizer, speaker_manager)
class VitsCharacters(BaseCharacters):
"""Characters class for VITs model for compatibility with pre-trained models"""
def __init__(
self,
graphemes: str = _characters,
punctuations: str = _punctuations,
pad: str = _pad,
ipa_characters: str = _phonemes,
) -> None:
if ipa_characters is not None:
graphemes += ipa_characters
super().__init__(graphemes, punctuations, pad, None, None, "<BLNK>", is_unique=False, is_sorted=True)
def _create_vocab(self):
self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
# pylint: disable=unnecessary-comprehension
self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
@staticmethod
def init_from_config(config: Coqpit):
if config.characters is not None:
_pad = config.characters["pad"]
_punctuations = config.characters["punctuations"]
_letters = config.characters["characters"]
_letters_ipa = config.characters["phonemes"]
return (
VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad),
config,
)
characters = VitsCharacters()
new_config = replace(config, characters=characters.to_config())
return characters, new_config
def to_config(self) -> "CharactersConfig":
return CharactersConfig(
characters=self._characters,
punctuations=self._punctuations,
pad=self._pad,
eos=None,
bos=None,
blank=self._blank,
is_unique=False,
is_sorted=True,
)
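
Taken together, the new entry point and the characters round-trip can be exercised roughly as below; a sketch assuming VitsConfig lives in TTS.tts.configs.vits_config and that config.characters is left unset so the default VitsCharacters path is taken:

# End-to-end sketch of the API added in this commit; Vits and VitsCharacters are the
# classes from this file, the VitsConfig import path is assumed.
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
model = Vits.init_from_config(config)   # samples=None: no speaker ids parsed

# characters/config round-trip used by the tokenizer layer
characters, new_config = VitsCharacters.init_from_config(config)
print(len(characters.vocab))        # [pad] + punctuations + characters + [blank], per _create_vocab()
print(new_config.characters.blank)  # "<BLNK>", as written by to_config()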