mirror of https://github.com/coqui-ai/TTS.git

Update glowtts docstrings and docs

parent 21126839a8 · commit 2e1a428b83
@@ -2,7 +2,7 @@
 Welcome to the 🐸TTS!
 
-This repository is governed by the Contributor Covenant Code of Conduct - [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
+This repository is governed by [the Contributor Covenant Code of Conduct](https://github.com/coqui-ai/TTS/blob/main/CODE_OF_CONDUCT.md).
 
 ## Where to start.
 We welcome everyone who likes to contribute to 🐸TTS.
@@ -3,7 +3,7 @@
 🐸TTS is a library for advanced Text-to-Speech generation. It's built on the latest research, was designed to achieve the best trade-off among ease-of-training, speed and quality.
 🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models), tools for measuring dataset quality and already used in **20+ languages** for products and research projects.
 
-[![CircleCI](https://github.com/coqui-ai/TTS/actions/workflows/main.yml/badge.svg)]()
+[![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/main.yml/badge.svg)](https://github.com/coqui-ai/TTS/actions)
 [![License](<https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg>)](https://opensource.org/licenses/MPL-2.0)
 [![Docs](<https://readthedocs.org/projects/tts/badge/?version=latest&style=plastic>)](https://tts.readthedocs.io/en/latest/)
 [![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS)
@@ -985,7 +985,7 @@ def get_last_checkpoint(path):
 
 
 def process_args(args, config=None):
-    """Process parsed comand line arguments.
+    """Process parsed command line arguments and initialize the config if not provided.
 
     Args:
        args (argparse.Namespace or dict like): Parsed input arguments.
@@ -7,7 +7,7 @@ from TTS.tts.configs.shared_configs import BaseTTSConfig
 class GlowTTSConfig(BaseTTSConfig):
     """Defines parameters for GlowTTS model.
 
-    Example:
+    Example:
 
         >>> from TTS.tts.configs import GlowTTSConfig
         >>> config = GlowTTSConfig()
@@ -12,7 +12,8 @@ def squeeze(x, x_mask=None, num_sqz=2):
 
     Note:
         each 's' is an n-dimensional vector.
-        [s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]"""
+        ``[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]``
+    """
     b, c, t = x.size()
 
     t = (t // num_sqz) * num_sqz
@@ -32,7 +33,8 @@ def unsqueeze(x, x_mask=None, num_sqz=2):
 
     Note:
         each 's' is an n-dimensional vector.
-        [[s1, s3, s5], [s2, s4, s6]] --> [s1, s2, s3, s4, s5, s6]"""
+        ``[[s1, s3, s5], [s2, s4, s6]] --> [s1, s2, s3, s4, s5, s6]``
+    """
     b, c, t = x.size()
 
     x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
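Note: the interleave documented above is easy to verify in isolation. A minimal plain-PyTorch sketch (independent of the repo's helpers) that reproduces both mappings:

```python
import torch

# "s1..s6" along the time axis: [B=1, C=1, T=6]
x = torch.arange(1, 7).view(1, 1, 6).float()
num_sqz = 2
b, c, t = x.size()

# squeeze: fold pairs of time steps into the channel axis
x_sqz = x.view(b, c, t // num_sqz, num_sqz)
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz)
print(x_sqz)    # [[s1, s3, s5], [s2, s4, s6]]

# unsqueeze: the exact inverse, recovering s1..s6
x_unsqz = x_sqz.view(b, num_sqz, c, t // num_sqz)
x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c, t)
print(x_unsqz)  # [[s1, s2, s3, s4, s5, s6]]
```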
@@ -47,7 +49,10 @@ def unsqueeze(x, x_mask=None, num_sqz=2):
 class Decoder(nn.Module):
     """Stack of Glow Decoder Modules.
-    Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze
+
+    ::
+
+        Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze
 
     Args:
         in_channels (int): channels of input tensor.
@@ -106,6 +111,12 @@ class Decoder(nn.Module):
         )
 
     def forward(self, x, x_mask, g=None, reverse=False):
+        """
+        Shapes:
+            - x: :math:`[B, C, T]`
+            - x_mask: :math:`[B, 1 ,T]`
+            - g: :math:`[B, C]`
+        """
         if not reverse:
             flows = self.flows
             logdet_tot = 0
@@ -6,13 +6,16 @@ from ..generic.normalization import LayerNorm
 class DurationPredictor(nn.Module):
     """Glow-TTS duration prediction model.
-    [2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs
 
-    Args:
-        in_channels ([type]): [description]
-        hidden_channels ([type]): [description]
-        kernel_size ([type]): [description]
-        dropout_p ([type]): [description]
+    ::
+
+        [2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs
+
+    Args:
+        in_channels (int): Number of channels of the input tensor.
+        hidden_channels (int): Number of hidden channels of the network.
+        kernel_size (int): Kernel size for the conv layers.
+        dropout_p (float): Dropout rate used after each conv layer.
     """
 
     def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
@@ -34,11 +37,8 @@ class DurationPredictor(nn.Module):
     def forward(self, x, x_mask):
         """
         Shapes:
-            x: [B, C, T]
-            x_mask: [B, 1, T]
-
-        Returns:
-            [type]: [description]
+            - x: :math:`[B, C, T]`
+            - x_mask: :math:`[B, 1, T]`
         """
         x = self.conv_1(x * x_mask)
         x = torch.relu(x)
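The constructor arguments and forward shapes above are enough to smoke-test the module. A hedged sketch (channel sizes are arbitrary; the import path assumes the repo layout at this revision):

```python
import torch
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor

dp = DurationPredictor(in_channels=192, hidden_channels=256, kernel_size=3, dropout_p=0.1)
x = torch.randn(2, 192, 50)      # [B, C, T]
x_mask = torch.ones(2, 1, 50)    # [B, 1, T]
durs = dp(x, x_mask)             # one (log-)duration value per input step
print(durs.shape)                # torch.Size([2, 1, 50])
```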
@@ -15,13 +15,16 @@ from TTS.tts.utils.data import sequence_mask
 class Encoder(nn.Module):
     """Glow-TTS encoder module.
 
-    embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
-                                                         |
-                                                         |-> proj_var
-                                                         |
-                                                         |-> concat -> duration_predictor
-                                                                ↑
-                                                          speaker_embed
+    ::
+
+        embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
+                                                             |
+                                                             |-> proj_var
+                                                             |
+                                                             |-> concat -> duration_predictor
+                                                                    ↑
+                                                              speaker_embed
 
     Args:
         num_chars (int): number of characters.
         out_channels (int): number of output channels.
@@ -36,7 +39,8 @@ class Encoder(nn.Module):
     Shapes:
         - input: (B, T, C)
 
     Notes:
+        ::
 
         suggested encoder params...
 
         for encoder_type == 'rel_pos_transformer'
@@ -139,9 +143,9 @@ class Encoder(nn.Module):
     def forward(self, x, x_lengths, g=None):
         """
         Shapes:
-            x: [B, C, T]
-            x_lengths: [B]
-            g (optional): [B, 1, T]
+            - x: :math:`[B, C, T]`
+            - x_lengths: :math:`[B]`
+            - g (optional): :math:`[B, 1, T]`
         """
         # embedding layer
         # [B ,T, D]
@@ -10,21 +10,24 @@ from ..generic.normalization import LayerNorm
 
 
 class ResidualConv1dLayerNormBlock(nn.Module):
+    """Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
+    https://arxiv.org/pdf/1811.00002.pdf
+
+    ::
+
+        x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
+        |---------------> conv1d_1x1 -----------------------|
+
+    Args:
+        in_channels (int): number of input tensor channels.
+        hidden_channels (int): number of inner layer channels.
+        out_channels (int): number of output tensor channels.
+        kernel_size (int): kernel size of conv1d filter.
+        num_layers (int): number of blocks.
+        dropout_p (float): dropout rate for each block.
+    """
+
     def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p):
-        """Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
-        https://arxiv.org/pdf/1811.00002.pdf
-
-        x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
-        |---------------> conv1d_1x1 -----------------------|
-
-        Args:
-            in_channels (int): number of input tensor channels.
-            hidden_channels (int): number of inner layer channels.
-            out_channels (int): number of output tensor channels.
-            kernel_size (int): kernel size of conv1d filter.
-            num_layers (int): number of blocks.
-            dropout_p (float): dropout rate for each block.
-        """
         super().__init__()
         self.in_channels = in_channels
         self.hidden_channels = hidden_channels
@@ -51,6 +54,11 @@ class ResidualConv1dLayerNormBlock(nn.Module):
         self.proj.bias.data.zero_()
 
     def forward(self, x, x_mask):
+        """
+        Shapes:
+            - x: :math:`[B, C, T]`
+            - x_mask: :math:`[B, 1, T]`
+        """
         x_res = x
         for i in range(self.num_layers):
             x = self.conv_layers[i](x * x_mask)
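With the docstring now at class level, usage follows directly from the Args and Shapes sections. A hedged sketch (sizes arbitrary; `in_channels` equals `out_channels` so the residual addition is valid; the import path assumes this block lives in the glow layers module):

```python
import torch
from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock

block = ResidualConv1dLayerNormBlock(
    in_channels=192, hidden_channels=256, out_channels=192,
    kernel_size=3, num_layers=2, dropout_p=0.1,
)
x = torch.randn(4, 192, 120)     # [B, C, T]
x_mask = torch.ones(4, 1, 120)   # [B, 1, T]
o = block(x, x_mask)
print(o.shape)                   # torch.Size([4, 192, 120])
```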
@@ -95,8 +103,8 @@ class InvConvNear(nn.Module):
     def forward(self, x, x_mask=None, reverse=False, **kwargs):  # pylint: disable=unused-argument
         """
         Shapes:
-            x: B x C x T
-            x_mask: B x 1 x T
+            - x: :math:`[B, C, T]`
+            - x_mask: :math:`[B, 1, T]`
         """
 
         b, c, t = x.size()
@@ -139,10 +147,12 @@ class CouplingBlock(nn.Module):
     """Glow Affine Coupling block as in GlowTTS paper.
     https://arxiv.org/pdf/1811.00002.pdf
 
-    x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
-    '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
+    ::
+
+        x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
+        '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
 
     Args:
         in_channels (int): number of input tensor channels.
         hidden_channels (int): number of hidden channels.
         kernel_size (int): WaveNet filter kernel size.
@@ -152,8 +162,8 @@ class CouplingBlock(nn.Module):
         dropout_p (int): wavenet dropout rate.
         sigmoid_scale (bool): enable/disable sigmoid scaling for output scale.
 
-    Note:
-        It does not use conditional inputs differently from WaveGlow.
+    Note:
+        It does not use the conditional inputs differently from WaveGlow.
     """
 
     def __init__(
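The diagram's `concat(s*x1 + t, x0)` is the whole invertibility argument: `x0` passes through untouched, so `s` and `t` can be recomputed from it to undo the transform exactly. A toy sketch of that idea (a 1x1 conv stands in for the WaveNet stack; this is not the repo's CouplingBlock):

```python
import torch
import torch.nn as nn

class ToyAffineCoupling(nn.Module):
    """Minimal affine coupling: x0 stays as-is and parameterizes (s, t) for x1."""
    def __init__(self, channels):
        super().__init__()
        self.net = nn.Conv1d(channels // 2, channels, 1)  # stand-in for the WaveNet stack

    def forward(self, x, reverse=False):
        x0, x1 = x.chunk(2, dim=1)
        t, log_s = self.net(x0).chunk(2, dim=1)  # (s, t) depend only on x0
        if not reverse:
            y1 = torch.exp(log_s) * x1 + t       # s*x1 + t
        else:
            y1 = (x1 - t) * torch.exp(-log_s)    # exact inverse
        return torch.cat([x0, y1], dim=1)

coupling = ToyAffineCoupling(8)
x = torch.randn(2, 8, 16)
z = coupling(x)                       # forward flow
x_rec = coupling(z, reverse=True)     # inverse flow
print(torch.allclose(x, x_rec, atol=1e-5))  # True
```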
@@ -193,9 +203,9 @@ class CouplingBlock(nn.Module):
     def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):  # pylint: disable=unused-argument
         """
         Shapes:
-            x: B x C x T
-            x_mask: B x 1 x T
-            g: B x C x 1
+            - x: :math:`[B, C, T]`
+            - x_mask: :math:`[B, 1, T]`
+            - g: :math:`[B, C, 1]`
         """
         if x_mask is None:
             x_mask = 1
@@ -17,16 +17,18 @@ class RelativePositionMultiHeadAttention(nn.Module):
 
     Note:
         Example with relative attention window size 2
-        input = [a, b, c, d, e]
-        rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)]
+
+        - input = [a, b, c, d, e]
+        - rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)]
 
         So it learns 4 embedding vectors (in total 8) separately for key and value vectors.
 
         Considering the input c
-        e(t-2) corresponds to c -> a
-        e(t-2) corresponds to c -> b
-        e(t-2) corresponds to c -> d
-        e(t-2) corresponds to c -> e
+
+        - e(t-2) corresponds to c -> a
+        - e(t-1) corresponds to c -> b
+        - e(t+1) corresponds to c -> d
+        - e(t+2) corresponds to c -> e
 
         These embeddings are shared among different time steps. So inputs a, b, d and e also use
         the same embeddings.
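A small sketch of that bookkeeping for window size 2: there are `2 * window` learned offset embeddings per projection (key and value), shared by every time step; each query position simply indexes the offsets that stay inside the sequence. Names here are illustrative, not the repo's:

```python
import torch.nn as nn

d_model = 16
# 2 * window relative offsets: t-2, t-1, t+1, t+2 (offset 0 is the token itself)
offsets = [-2, -1, 1, 2]
rel_key_emb = nn.Embedding(len(offsets), d_model)  # 4 vectors, shared by all time steps
rel_val_emb = nn.Embedding(len(offsets), d_model)  # separate 4 vectors for the values

seq = ["a", "b", "c", "d", "e"]
t = 2  # the input "c"
for i, off in enumerate(offsets):
    j = t + off
    if 0 <= j < len(seq):  # offsets falling outside the sequence are skipped
        print(f"e(t{off:+d}) relates {seq[t]} -> {seq[j]} via embedding row {i}")
# e(t-2) relates c -> a, e(t-1) relates c -> b, e(t+1) relates c -> d, e(t+2) relates c -> e
```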
@@ -106,6 +108,12 @@ class RelativePositionMultiHeadAttention(nn.Module):
         nn.init.xavier_uniform_(self.conv_v.weight)
 
     def forward(self, x, c, attn_mask=None):
+        """
+        Shapes:
+            - x: :math:`[B, C, T]`
+            - c: :math:`[B, C, T]`
+            - attn_mask: :math:`[B, 1, T, T]`
+        """
         q = self.conv_q(x)
         k = self.conv_k(c)
         v = self.conv_v(c)
@@ -163,9 +171,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
             re (Tensor): relative value embedding vector. (a_(i,j)^V)
 
         Shapes:
-            p_attn: [B, H, T, V]
-            re: [H or 1, V, D]
-            logits: [B, H, T, D]
+            - p_attn: :math:`[B, H, T, V]`
+            - re: :math:`[H or 1, V, D]`
+            - logits: :math:`[B, H, T, D]`
         """
         logits = torch.matmul(p_attn, re.unsqueeze(0))
         return logits
@@ -178,9 +186,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
             re (Tensor): relative key embedding vector. (a_(i,j)^K)
 
         Shapes:
-            query: [B, H, T, D]
-            re: [H or 1, V, D]
-            logits: [B, H, T, V]
+            - query: :math:`[B, H, T, D]`
+            - re: :math:`[H or 1, V, D]`
+            - logits: :math:`[B, H, T, V]`
         """
         # logits = torch.einsum('bhld, kmd -> bhlm', [query, re.to(query.dtype)])
         logits = torch.matmul(query, re.unsqueeze(0).transpose(-2, -1))
@@ -202,10 +210,10 @@ class RelativePositionMultiHeadAttention(nn.Module):
     @staticmethod
     def _relative_position_to_absolute_position(x):
         """Converts tensor from relative to absolute indexing for local attention.
-        Args:
-            x: [B, D, length, 2 * length - 1]
+        Shapes:
+            x: :math:`[B, C, T, 2 * T - 1]`
         Returns:
-            A Tensor of shape [B, D, length, length]
+            A Tensor of shape :math:`[B, C, T, T]`
         """
         batch, heads, length, _ = x.size()
         # Pad to shift from relative to absolute indexing.
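The conversion itself is the standard pad-and-reshape trick: one extra column plus a flat pad skews the rows so that relative offsets line up as absolute positions. A self-contained sketch matching the shapes documented above:

```python
import torch
import torch.nn.functional as F

def rel_to_abs(x):
    """[B, C, T, 2*T-1] -> [B, C, T, T] via the pad-and-reshape trick."""
    b, h, t, _ = x.size()
    x = F.pad(x, (0, 1))                       # pad one column: [B, C, T, 2T]
    x_flat = x.reshape(b, h, t * 2 * t)        # flatten the last two axes
    x_flat = F.pad(x_flat, (0, t - 1))         # skew: now (T+1)*(2T-1) elements
    return x_flat.reshape(b, h, t + 1, 2 * t - 1)[:, :, :t, t - 1:]

x = torch.randn(2, 4, 5, 9)                    # [B, C, T, 2*T-1] with T=5
print(rel_to_abs(x).shape)                     # torch.Size([2, 4, 5, 5])
```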
@@ -220,8 +228,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
     @staticmethod
     def _absolute_position_to_relative_position(x):
         """
-        x: [B, H, T, T]
-        ret: [B, H, T, 2*T-1]
+        Shapes:
+            - x: :math:`[B, C, T, T]`
+            - ret: :math:`[B, C, T, 2*T-1]`
         """
         batch, heads, length, _ = x.size()
         # pad along column
@@ -239,7 +248,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
         Args:
             length (int): an integer scalar.
         Returns:
-            a Tensor with shape [1, 1, length, length]
+            a Tensor with shape :math:`[1, 1, T, T]`
         """
         # L
         r = torch.arange(length, dtype=torch.float32)
@@ -362,8 +371,8 @@ class RelativePositionTransformer(nn.Module):
     def forward(self, x, x_mask):
         """
         Shapes:
-            x: [B, C, T]
-            x_mask: [B, 1, T]
+            - x: :math:`[B, C, T]`
+            - x_mask: :math:`[B, 1, T]`
         """
         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
         for i in range(self.num_layers):
@@ -30,24 +30,31 @@ class GlowTTS(BaseTTS):
     the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our
     model can be easily extended to a multi-speaker setting.
 
-    Check `GlowTTSConfig` for class arguments.
+    Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.
+
+    Examples:
+        >>> from TTS.tts.configs import GlowTTSConfig
+        >>> from TTS.tts.models.glow_tts import GlowTTS
+        >>> config = GlowTTSConfig()
+        >>> model = GlowTTS(config)
     """
 
     def __init__(self, config: GlowTTSConfig):
 
         super().__init__()
 
-        chars, self.config = self.get_characters(config)
-        self.num_chars = len(chars)
-        self.decoder_output_dim = config.out_channels
-        self.init_multispeaker(config)
-
         # pass all config fields to `self`
         # for fewer code change
         self.config = config
         for key in config:
             setattr(self, key, config[key])
 
+        chars, self.config = self.get_characters(config)
+        self.num_chars = len(chars)
+        self.decoder_output_dim = config.out_channels
+        self.init_multispeaker(config)
+
         # if is a multispeaker and c_in_channels is 0, set to 256
         self.c_in_channels = 0
         if self.num_speakers > 1:
@@ -91,7 +98,7 @@ class GlowTTS(BaseTTS):
 
     @staticmethod
     def compute_outputs(attn, o_mean, o_log_scale, x_mask):
-        # compute final values with the computed alignment
+        """Compute and format the model outputs with the given alignment map."""
         y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
             1, 2
         )  # [b, t', t], [b, t, d] -> [b, d, t']
@@ -107,11 +114,11 @@ class GlowTTS(BaseTTS):
     ):  # pylint: disable=dangerous-default-value
         """
         Shapes:
-            x: [B, T]
-            x_lenghts: B
-            y: [B, T, C]
-            y_lengths: B
-            g: [B, C] or B
+            - x: :math:`[B, T]`
+            - x_lengths: :math:`[B]`
+            - y: :math:`[B, T, C]`
+            - y_lengths: :math:`[B]`
+            - g: :math:`[B, C]` or :math:`[B]`
         """
         y = y.transpose(1, 2)
         y_max_length = y.size(2)
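Putting the documented shapes together with the class example above, a hedged smoke-test sketch (argument order follows the Shapes list; whether this exact call runs depends on revision defaults, e.g. mel lengths being compatible with the decoder's squeeze factor):

```python
import torch
from TTS.tts.configs import GlowTTSConfig
from TTS.tts.models.glow_tts import GlowTTS

config = GlowTTSConfig()
model = GlowTTS(config)

x = torch.randint(0, model.num_chars, (2, 40))   # token ids, [B, T]
x_lengths = torch.tensor([40, 31])               # [B]
y = torch.randn(2, 160, config.out_channels)     # mel frames, [B, T, C]
y_lengths = torch.tensor([160, 120])             # [B]

outputs = model.forward(x, x_lengths, y, y_lengths)
```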
@@ -161,12 +168,13 @@ class GlowTTS(BaseTTS):
         """
         It's similar to the teacher forcing in Tacotron.
         It was proposed in: https://arxiv.org/abs/2104.05557
 
         Shapes:
-            x: [B, T]
-            x_lenghts: B
-            y: [B, T, C]
-            y_lengths: B
-            g: [B, C] or B
+            - x: :math:`[B, T]`
+            - x_lengths: :math:`[B]`
+            - y: :math:`[B, T, C]`
+            - y_lengths: :math:`[B]`
+            - g: :math:`[B, C]` or :math:`[B]`
         """
         y = y.transpose(1, 2)
         y_max_length = y.size(2)
@@ -221,9 +229,9 @@ class GlowTTS(BaseTTS):
     ):  # pylint: disable=dangerous-default-value
         """
         Shapes:
-            y: [B, T, C]
-            y_lengths: B
-            g: [B, C] or B
+            - y: :math:`[B, T, C]`
+            - y_lengths: :math:`[B]`
+            - g: :math:`[B, C]` or :math:`[B]`
         """
         y = y.transpose(1, 2)
         y_max_length = y.size(2)
@@ -54,7 +54,7 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
             Tensor: spectrogram frames.
 
         Shapes:
-            x: [B x T] or [B x 1 x T]
+            x: :math:`[B, T]` or :math:`[B, 1, T]`
         """
         if x.ndim == 2:
             x = x.unsqueeze(1)
@@ -22,6 +22,9 @@ class GAN(BaseVocoder):
     """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
     It also helps mixing and matching different generator and discriminator networks easily.
 
+    To implement a new GAN model, you just need to define the generator and the discriminator networks; the rest
+    is handled by the `GAN` class.
+
     Args:
         config (Coqpit): Model configuration.
@@ -39,12 +42,41 @@ class GAN(BaseVocoder):
         self.y_hat_g = None  # the last generator prediction to be passed onto the discriminator
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Run the generator's forward pass.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: output of the GAN generator network.
+        """
         return self.model_g.forward(x)
 
     def inference(self, x: torch.Tensor) -> torch.Tensor:
+        """Run the generator's inference pass.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: output of the GAN generator network.
+        """
         return self.model_g.inference(x)
 
     def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
+        """Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator
+        network for the current pass.
+
+        Args:
+            batch (Dict): Batch of samples returned by the dataloader.
+            criterion (Dict): Criterion used to compute the losses.
+            optimizer_idx (int): ID of the optimizer in use on the current pass.
+
+        Raises:
+            ValueError: `optimizer_idx` is an unexpected value.
+
+        Returns:
+            Tuple[Dict, Dict]: model outputs and the computed loss values.
+        """
         outputs = None
         loss_dict = None
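The `optimizer_idx` contract above implies the usual alternating GAN loop on the trainer side. A schematic sketch (not the Trainer's actual code; the `loader`, `optimizers`, `criterion` names and the `"loss"` key are assumptions based on these docstrings):

```python
# optimizers[0] -> generator, optimizers[1] -> discriminator (see get_optimizer below)
for batch in loader:
    batch = model.format_batch(batch)
    for optimizer_idx, optimizer in enumerate(optimizers):
        outputs, loss_dict = model.train_step(batch, criterion, optimizer_idx)
        if loss_dict is not None:  # e.g. discriminator still disabled early in training
            optimizer.zero_grad()
            loss_dict["loss"].backward()
            optimizer.step()
```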
@@ -145,7 +177,18 @@ class GAN(BaseVocoder):
         return outputs, loss_dict
 
     @staticmethod
-    def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+    def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
+        """Logging shared by the training and evaluation.
+
+        Args:
+            name (str): Name of the run. `train` or `eval`.
+            ap (AudioProcessor): Audio processor used in training.
+            batch (Dict): Batch used in the last train/eval step.
+            outputs (Dict): Model outputs from the last train/eval step.
+
+        Returns:
+            Tuple[Dict, Dict]: log figures and audio samples.
+        """
         y_hat = outputs[0]["model_outputs"]
         y = batch["waveform"]
         figures = plot_results(y_hat, y, ap, name)
@@ -154,13 +197,16 @@ class GAN(BaseVocoder):
         return figures, audios
 
     def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+        """Call `_log()` for training."""
         return self._log("train", ap, batch, outputs)
 
     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
+        """Call `train_step()` with `no_grad()`."""
         return self.train_step(batch, criterion, optimizer_idx)
 
     def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+        """Call `_log()` for evaluation."""
         return self._log("eval", ap, batch, outputs)
 
     def load_checkpoint(
@@ -169,6 +215,13 @@ class GAN(BaseVocoder):
         checkpoint_path: str,
         eval: bool = False,  # pylint: disable=unused-argument, redefined-builtin
     ) -> None:
+        """Load a GAN checkpoint and initialize model parameters.
+
+        Args:
+            config (Coqpit): Model config.
+            checkpoint_path (str): Checkpoint file path.
+            eval (bool, optional): If true, load the model for inference. Defaults to False.
+        """
         state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
         # band-aid for older than v0.0.15 GAN models
         if "model_disc" in state:
@@ -181,9 +234,21 @@ class GAN(BaseVocoder):
             self.model_g.remove_weight_norm()
 
     def on_train_step_start(self, trainer) -> None:
+        """Enable the discriminator training based on `steps_to_start_discriminator`.
+
+        Args:
+            trainer (Trainer): Trainer object.
+        """
         self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator
 
-    def get_optimizer(self):
+    def get_optimizer(self) -> List:
+        """Initiate and return the GAN optimizers based on the config parameters.
+
+        It returns two optimizers in a list: the first for the generator, the second for the discriminator.
+
+        Returns:
+            List: optimizers.
+        """
         optimizer1 = get_optimizer(
             self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
         )
@@ -192,16 +257,37 @@ class GAN(BaseVocoder):
         )
         return [optimizer1, optimizer2]
 
-    def get_lr(self):
+    def get_lr(self) -> List:
+        """Set the initial learning rates for each optimizer.
+
+        Returns:
+            List: learning rates for each optimizer.
+        """
         return [self.config.lr_gen, self.config.lr_disc]
 
-    def get_scheduler(self, optimizer):
+    def get_scheduler(self, optimizer) -> List:
+        """Set the schedulers for each optimizer.
+
+        Args:
+            optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
+
+        Returns:
+            List: Schedulers, one for each optimizer.
+        """
         scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
         scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
         return [scheduler1, scheduler2]
 
     @staticmethod
-    def format_batch(batch):
+    def format_batch(batch: List) -> Dict:
+        """Format the batch for training.
+
+        Args:
+            batch (List): Batch out of the dataloader.
+
+        Returns:
+            Dict: formatted model inputs.
+        """
         if isinstance(batch[0], list):
             x_G, y_G = batch[0]
             x_D, y_D = batch[1]
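Taken together, these methods keep the generator/discriminator pairing positional: index 0 is always the generator, index 1 the discriminator. A short sketch of how the lists line up (assuming a configured `GAN` instance named `model`):

```python
optimizers = model.get_optimizer()             # [generator_opt, discriminator_opt]
learning_rates = model.get_lr()                # [lr_gen, lr_disc], same order
schedulers = model.get_scheduler(optimizers)   # one scheduler per optimizer

for opt, lr in zip(optimizers, learning_rates):
    for group in opt.param_groups:
        group["lr"] = lr                       # apply the initial learning rates
```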
@@ -218,6 +304,19 @@ class GAN(BaseVocoder):
         verbose: bool,
         num_gpus: int,
     ):
+        """Initiate and return the GAN dataloader.
+
+        Args:
+            config (Coqpit): Model config.
+            ap (AudioProcessor): Audio processor.
+            is_eval (bool): Set the dataloader for evaluation if true.
+            data_items (List): Data samples.
+            verbose (bool): Log information if true.
+            num_gpus (int): Number of GPUs in use.
+
+        Returns:
+            DataLoader: Torch dataloader.
+        """
         dataset = GANDataset(
             ap=ap,
             items=data_items,
@@ -34,7 +34,7 @@ class PQMF(tf.keras.layers.Layer):
 
     def analysis(self, x):
         """
-        x : B x 1 x T
+        x : :math:`[B, 1, T]`
         """
         x = tf.transpose(x, perm=[0, 2, 1])
         x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0)
@@ -92,7 +92,7 @@ class MelganGenerator(tf.keras.models.Model):
     @tf.function(experimental_relax_shapes=True)
     def call(self, c, training=False):
         """
-        c : B x C x T
+        c : :math:`[B, C, T]`
         """
         if training:
             raise NotImplementedError()
@@ -113,7 +113,7 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
     """
     Sample from discretized mixture of logistic distributions
     Args:
-        y (Tensor): B x C x T
+        y (Tensor): :math:`[B, C, T]`
         log_scale_min (float): Log scale minimum value
     Returns:
         Tensor: sample in range of [-1, 1].
@@ -1,25 +0,0 @@
-# AudioProcessor
-
-`TTS.utils.audio.AudioProcessor` is the core class for all the audio processing routines. It provides an API for
-
-- Feature extraction.
-- Sound normalization.
-- Reading and writing audio files.
-- Sampling audio signals.
-- Normalizing and denormalizing audio signals.
-- Griffin-Lim vocoder.
-
-The `AudioProcessor` needs to be initialized with `TTS.config.shared_configs.BaseAudioConfig`. Any model config
-also must inherit or initiate `BaseAudioConfig`.
-
-## AudioProcessor
-```{eval-rst}
-.. autoclass:: TTS.utils.audio.AudioProcessor
-    :members:
-```
-
-## BaseAudioConfig
-```{eval-rst}
-.. autoclass:: TTS.config.shared_configs.BaseAudioConfig
-    :members:
-```
@@ -50,6 +50,43 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'TODO/*']
 
 source_suffix = [".rst", ".md"]
 
+# extensions
+extensions = [
+    'sphinx.ext.autodoc',
+    'sphinx.ext.autosummary',
+    'sphinx.ext.doctest',
+    'sphinx.ext.intersphinx',
+    'sphinx.ext.todo',
+    'sphinx.ext.coverage',
+    'sphinx.ext.napoleon',
+    'sphinx.ext.viewcode',
+    'sphinx.ext.autosectionlabel',
+    'myst_parser',
+    "sphinx_copybutton",
+    "sphinx_inline_tabs",
+]
+
+# 'sphinxcontrib.katex',
+# 'sphinx.ext.autosectionlabel',
+
+
+# autosectionlabel throws warnings if section names are duplicated.
+# The following tells autosectionlabel to not throw a warning for
+# duplicated section names that are in different documents.
+autosectionlabel_prefix_document = True
+
+language = None
+
+autodoc_inherit_docstrings = False
+
+# Disable displaying type annotations, these can be very verbose
+autodoc_typehints = 'none'
+
+# Enable overriding of function signatures in the first line of the docstring.
+autodoc_docstring_signature = True
+
+napoleon_custom_sections = [('Shapes', 'shape')]
+
 
 # -- Options for HTML output -------------------------------------------------
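The `napoleon_custom_sections = [('Shapes', 'shape')]` entry is what lets the docstrings in this commit carry a `Shapes:` section; napoleon then renders it like a built-in section. For example:

```python
def forward(self, x, x_mask):
    """Example docstring using the custom section enabled above.

    Shapes:
        - x: :math:`[B, C, T]`
        - x_mask: :math:`[B, 1, T]`
    """
```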
@@ -80,23 +117,3 @@ html_sidebars = {
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
 html_static_path = ['_static']
-
-
-# using markdown
-extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.autosummary',
-    'sphinx.ext.doctest',
-    'sphinx.ext.intersphinx',
-    'sphinx.ext.todo',
-    'sphinx.ext.coverage',
-    'sphinx.ext.napoleon',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.autosectionlabel',
-    'myst_parser',
-    "sphinx_copybutton",
-    "sphinx_inline_tabs",
-]
-
-# 'sphinxcontrib.katex',
-# 'sphinx.ext.autosectionlabel',
@@ -1,4 +1,4 @@
-# Converting Torch Tacotron to TF 2
+# Converting Torch to TF 2
 
 Currently, 🐸TTS supports the vanilla Tacotron2 and MelGAN models in TF 2. It does not support advanced attention methods and other small tricks used by the Torch models. You can convert any Torch model trained after v0.0.2.
@@ -1,25 +0,0 @@
-# Datasets
-
-## TTS Dataset
-
-```{eval-rst}
-.. autoclass:: TTS.tts.datasets.TTSDataset
-    :members:
-```
-
-## Vocoder Dataset
-
-```{eval-rst}
-.. autoclass:: TTS.vocoder.datasets.gan_dataset.GANDataset
-    :members:
-```
-
-```{eval-rst}
-.. autoclass:: TTS.vocoder.datasets.wavegrad_dataset.WaveGradDataset
-    :members:
-```
-
-```{eval-rst}
-.. autoclass:: TTS.vocoder.datasets.wavernn_dataset.WaveRNNDataset
-    :members:
-```
@@ -105,7 +105,7 @@ The best approach is to pick a set of promising models and run a Mean-Opinion-Sc
 - Check the 4th step under "How can I check model performance?"
 
 ## How can I test a trained model?
-- The best way is to use `tts` or `tts-server` commands. For details check {ref}`here <Synthesizing Speech>`.
+- The best way is to use `tts` or `tts-server` commands. For details check {ref}`here <synthesizing_speech>`.
 - If you need to code your own ```TTS.utils.synthesizer.Synthesizer``` class.
 
 ## My Tacotron model does not stop - I see "Decoder stopped with 'max_decoder_steps" - Stopnet does not work.
@@ -36,7 +36,7 @@
 There is also the `callback` interface by which you can manipulate both the model and the `Trainer` states. Callbacks give you
 the infinite flexibility to add custom behaviours for your model and training routines.
 
-For more details, see {ref}`BaseTTS <Base TTS Model>` and `TTS/utils/callbacks.py`.
+For more details, see {ref}`BaseTTS <Base TTS Model>` and :obj:`TTS.utils.callbacks`.
 
 6. Optionally, define `MyModelArgs`.
@@ -2,7 +2,6 @@
 ```{include} ../../README.md
 :relative-images:
 ```
-
 ----
 
 # Documentation Content

@@ -27,14 +26,28 @@
    formatting_your_dataset
    what_makes_a_good_dataset
    tts_datasets
    converting_torch_to_tf
 
 .. toctree::
    :maxdepth: 2
    :caption: Main Classes
 
-   trainer_api
-   audio_processor
-   model_api
-   configuration
-   dataset
-```
+   main_classes/trainer_api
+   main_classes/audio_processor
+   main_classes/model_api
+   main_classes/dataset
+   main_classes/gan
+
+.. toctree::
+   :maxdepth: 2
+   :caption: `tts` Models
+
+   models/glow_tts.md
+
+.. toctree::
+   :maxdepth: 2
+   :caption: `vocoder` Models
+
+   main_classes/gan
+```
@@ -1,4 +1,4 @@
-# AudioProcessor
+# AudioProcessor API
 
 `TTS.utils.audio.AudioProcessor` is the core class for all the audio processing routines. It provides an API for
@@ -19,6 +19,6 @@ Model API provides you a set of functions that easily make your model compatible
 ## Base `vocoder` Model
 
 ```{eval-rst}
-.. autoclass:: TTS.tts.models.base_vocoder.BaseVocoder`
+.. autoclass:: TTS.vocoder.models.base_vocoder.BaseVocoder
     :members:
 ```
@@ -1,24 +0,0 @@
-# Model API
-Model API provides you a set of functions that easily make your model compatible with the `Trainer`,
-`Synthesizer` and `ModelZoo`.
-
-## Base TTS Model
-
-```{eval-rst}
-.. autoclass:: TTS.model.BaseModel
-    :members:
-```
-
-## Base `tts` Model
-
-```{eval-rst}
-.. autoclass:: TTS.tts.models.base_tts.BaseTTS
-    :members:
-```
-
-## Base `vocoder` Model
-
-```{eval-rst}
-.. autoclass:: TTS.tts.models.base_vocoder.BaseVocoder`
-    :members:
-```
@@ -1,17 +0,0 @@
-# Trainer API
-
-The {class}`TTS.trainer.Trainer` provides a lightweight, extensible, and feature-complete training run-time. We optimized it for 🐸 but
-can also be used for any DL training in different domains. It supports distributed multi-gpu, mixed-precision (apex or torch.amp) training.
-
-
-## Trainer
-```{eval-rst}
-.. autoclass:: TTS.trainer.Trainer
-    :members:
-```
-
-## TrainingArgs
-```{eval-rst}
-.. autoclass:: TTS.trainer.TrainingArgs
-    :members:
-```