Update glowtts docstrings and docs

pull/602/head
Eren Gölge 2021-06-30 14:30:55 +02:00
parent 21126839a8
commit 2e1a428b83
26 changed files with 305 additions and 225 deletions

View File

@ -2,7 +2,7 @@
Welcome to the 🐸TTS!
This repository is governed by the Contributor Covenant Code of Conduct - [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
This repository is governed by [the Contributor Covenant Code of Conduct](https://github.com/coqui-ai/TTS/blob/main/CODE_OF_CONDUCT.md).
## Where to start.
We welcome everyone who likes to contribute to 🐸TTS.

View File

@ -3,7 +3,7 @@
🐸TTS is a library for advanced Text-to-Speech generation. It's built on the latest research, was designed to achieve the best trade-off among ease-of-training, speed and quality.
🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models), tools for measuring dataset quality and already used in **20+ languages** for products and research projects.
[![CircleCI](https://github.com/coqui-ai/TTS/actions/workflows/main.yml/badge.svg)]()
[![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/main.yml/badge.svg)](https://github.com/coqui-ai/TTS/actions)
[![License](<https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg>)](https://opensource.org/licenses/MPL-2.0)
[![Docs](<https://readthedocs.org/projects/tts/badge/?version=latest&style=plastic>)](https://tts.readthedocs.io/en/latest/)
[![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS)

View File

@ -985,7 +985,7 @@ def get_last_checkpoint(path):
def process_args(args, config=None):
"""Process parsed comand line arguments.
"""Process parsed comand line arguments and initialize the config if not provided.
Args:
args (argparse.Namespace or dict like): Parsed input arguments.

View File

@ -7,7 +7,7 @@ from TTS.tts.configs.shared_configs import BaseTTSConfig
class GlowTTSConfig(BaseTTSConfig):
"""Defines parameters for GlowTTS model.
Example:
Example:
>>> from TTS.tts.configs import GlowTTSConfig
>>> config = GlowTTSConfig()
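
For illustration, a hedged extension of the example above: Coqpit-based configs accept keyword overrides at construction time. The field names used here (`batch_size`, `lr`) come from the shared base training config and are assumptions, not part of this change.

```python
from TTS.tts.configs import GlowTTSConfig

# Override a couple of (assumed) base-config fields at construction time.
config = GlowTTSConfig(batch_size=32, lr=1e-3)
print(config.batch_size, config.lr)
```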

View File

@ -12,7 +12,8 @@ def squeeze(x, x_mask=None, num_sqz=2):
Note:
each 's' is an n-dimensional vector.
[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]"""
``[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]``
"""
b, c, t = x.size()
t = (t // num_sqz) * num_sqz
@ -32,7 +33,8 @@ def unsqueeze(x, x_mask=None, num_sqz=2):
Note:
each 's' is an n-dimensional vector.
[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]"""
``[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]``
"""
b, c, t = x.size()
x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
@ -47,7 +49,10 @@ def unsqueeze(x, x_mask=None, num_sqz=2):
class Decoder(nn.Module):
"""Stack of Glow Decoder Modules.
Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze
::
Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze
Args:
in_channels (int): channels of input tensor.
@ -106,6 +111,12 @@ class Decoder(nn.Module):
)
def forward(self, x, x_mask, g=None, reverse=False):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C]`
"""
if not reverse:
flows = self.flows
logdet_tot = 0
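
A minimal sketch of the squeeze reshape described in the note above, using plain torch; it mirrors the `[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]` interleave but is not the library function itself.

```python
import torch

B, C, T, num_sqz = 2, 80, 6, 2
x = torch.randn(B, C, T)                          # [B, C, T]
t = (T // num_sqz) * num_sqz                      # trim T to a multiple of num_sqz
x = x[:, :, :t]
x_sqz = x.view(B, C, t // num_sqz, num_sqz)       # group every num_sqz steps
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(B, C * num_sqz, t // num_sqz)
print(x_sqz.shape)                                # torch.Size([2, 160, 3])
```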

View File

@ -6,13 +6,16 @@ from ..generic.normalization import LayerNorm
class DurationPredictor(nn.Module):
"""Glow-TTS duration prediction model.
[2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs
Args:
in_channels ([type]): [description]
hidden_channels ([type]): [description]
kernel_size ([type]): [description]
dropout_p ([type]): [description]
::
[2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs
Args:
in_channels (int): Number of channels of the input tensor.
hidden_channels (int): Number of hidden channels of the network.
kernel_size (int): Kernel size for the conv layers.
dropout_p (float): Dropout rate used after each conv layer.
"""
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
@ -34,11 +37,8 @@ class DurationPredictor(nn.Module):
def forward(self, x, x_mask):
"""
Shapes:
x: [B, C, T]
x_mask: [B, 1, T]
Returns:
[type]: [description]
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
"""
x = self.conv_1(x * x_mask)
x = torch.relu(x)
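
A hedged usage sketch for the duration predictor documented above; the constructor arguments follow the Args list, while the import path and the concrete sizes are assumptions.

```python
import torch
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor  # path is an assumption

dp = DurationPredictor(in_channels=192, hidden_channels=256, kernel_size=3, dropout_p=0.1)
x = torch.randn(4, 192, 50)        # [B, C, T]
x_mask = torch.ones(4, 1, 50)      # [B, 1, T]
durs = dp(x, x_mask)               # predicted (log-)durations, [B, 1, T]
print(durs.shape)
```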

View File

@ -15,13 +15,16 @@ from TTS.tts.utils.data import sequence_mask
class Encoder(nn.Module):
"""Glow-TTS encoder module.
embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
|
|-> proj_var
|
|-> concat -> duration_predictor
speaker_embed
::
embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
|
|-> proj_var
|
|-> concat -> duration_predictor
speaker_embed
Args:
num_chars (int): number of characters.
out_channels (int): number of output channels.
@ -36,7 +39,8 @@ class Encoder(nn.Module):
Shapes:
- input: (B, T, C)
Notes:
::
suggested encoder params...
for encoder_type == 'rel_pos_transformer'
@ -139,9 +143,9 @@ class Encoder(nn.Module):
def forward(self, x, x_lengths, g=None):
"""
Shapes:
x: [B, C, T]
x_lengths: [B]
g (optional): [B, 1, T]
- x: :math:`[B, C, T]`
- x_lengths: :math:`[B]`
- g (optional): :math:`[B, 1, T]`
"""
# embedding layer
# [B ,T, D]
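
For reference, a sketch of the `encoder_params` dict the note above points to for `encoder_type == 'rel_pos_transformer'`; the keys and values mirror the library defaults and are stated here as an assumption.

```python
# Assumed default parameters for the relative-position transformer encoder.
encoder_params = {
    "kernel_size": 3,
    "dropout_p": 0.1,
    "num_layers": 6,
    "num_heads": 2,
    "hidden_channels_ffn": 768,
    "input_length": None,
}
```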

View File

@ -10,21 +10,24 @@ from ..generic.normalization import LayerNorm
class ResidualConv1dLayerNormBlock(nn.Module):
"""Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf
::
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|---------------> conv1d_1x1 -----------------------|
Args:
in_channels (int): number of input tensor channels.
hidden_channels (int): number of inner layer channels.
out_channels (int): number of output tensor channels.
kernel_size (int): kernel size of conv1d filter.
num_layers (int): number of blocks.
dropout_p (float): dropout rate for each block.
"""
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p):
"""Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|---------------> conv1d_1x1 -----------------------|
Args:
in_channels (int): number of input tensor channels.
hidden_channels (int): number of inner layer channels.
out_channels (int): number of output tensor channels.
kernel_size (int): kernel size of conv1d filter.
num_layers (int): number of blocks.
dropout_p (float): dropout rate for each block.
"""
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
@ -51,6 +54,11 @@ class ResidualConv1dLayerNormBlock(nn.Module):
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
"""
x_res = x
for i in range(self.num_layers):
x = self.conv_layers[i](x * x_mask)
@ -95,8 +103,8 @@ class InvConvNear(nn.Module):
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
"""
Shapes:
x: B x C x T
x_mask: B x 1 x T
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
"""
b, c, t = x.size()
@ -139,10 +147,12 @@ class CouplingBlock(nn.Module):
"""Glow Affine Coupling block as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf
x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
'-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
::
Args:
x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
'-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
Args:
in_channels (int): number of input tensor channels.
hidden_channels (int): number of hidden channels.
kernel_size (int): WaveNet filter kernel size.
@ -152,8 +162,8 @@ class CouplingBlock(nn.Module):
dropout_p (int): wavenet dropout rate.
sigmoid_scale (bool): enable/disable sigmoid scaling for output scale.
Note:
It does not use conditional inputs differently from WaveGlow.
Note:
It does not use the conditional inputs differently from WaveGlow.
"""
def __init__(
@ -193,9 +203,9 @@ class CouplingBlock(nn.Module):
def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument
"""
Shapes:
x: B x C x T
x_mask: B x 1 x T
g: B x C x 1
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
"""
if x_mask is None:
x_mask = 1
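
A minimal numerical sketch of the affine coupling transform drawn above: one half of the channels passes through unchanged and parameterizes a scale and shift applied to the other half, which keeps the flow invertible. The random `t`/`logs` here stand in for the WaveNet outputs, and the mask and sigmoid-scale options are omitted.

```python
import torch

x = torch.randn(2, 80, 40)                               # [B, C, T]
x0, x1 = x[:, :40], x[:, 40:]                            # split channels in half
t = torch.randn_like(x1)                                 # shift (normally predicted from x0)
logs = torch.randn_like(x1)                              # log-scale (normally predicted from x0)
y = torch.cat([x0, x1 * torch.exp(logs) + t], dim=1)     # forward direction
x1_rec = (y[:, 40:] - t) * torch.exp(-logs)              # reverse direction recovers x1
print(torch.allclose(x1_rec, x1, atol=1e-4))
```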

View File

@ -17,16 +17,18 @@ class RelativePositionMultiHeadAttention(nn.Module):
Note:
Example with relative attention window size 2
input = [a, b, c, d, e]
rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)]
- input = [a, b, c, d, e]
- rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)]
So it learns 4 embedding vectors (in total 8) separately for key and value vectors.
Considering the input c
e(t-2) corresponds to c -> a
e(t-2) corresponds to c -> b
e(t-2) corresponds to c -> d
e(t-2) corresponds to c -> e
- e(t-2) corresponds to c -> a
- e(t-1) corresponds to c -> b
- e(t+1) corresponds to c -> d
- e(t+2) corresponds to c -> e
These embeddings are shared among different time steps. So inputs a, b, d and e also use
the same embeddings.
@ -106,6 +108,12 @@ class RelativePositionMultiHeadAttention(nn.Module):
nn.init.xavier_uniform_(self.conv_v.weight)
def forward(self, x, c, attn_mask=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- c: :math:`[B, C, T]`
- attn_mask: :math:`[B, 1, T, T]`
"""
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
@ -163,9 +171,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
re (Tensor): relative value embedding vector. (a_(i,j)^V)
Shapes:
p_attn: [B, H, T, V]
re: [H or 1, V, D]
logits: [B, H, T, D]
- p_attn: :math:`[B, H, T, V]`
- re: :math:`[H or 1, V, D]`
- logits: :math:`[B, H, T, D]`
"""
logits = torch.matmul(p_attn, re.unsqueeze(0))
return logits
@ -178,9 +186,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
re (Tensor): relative key embedding vector. (a_(i,j)^K)
Shapes:
query: [B, H, T, D]
re: [H or 1, V, D]
logits: [B, H, T, V]
- query: :math:`[B, H, T, D]`
- re: :math:`[H or 1, V, D]`
- logits: :math:`[B, H, T, V]`
"""
# logits = torch.einsum('bhld, kmd -> bhlm', [query, re.to(query.dtype)])
logits = torch.matmul(query, re.unsqueeze(0).transpose(-2, -1))
@ -202,10 +210,10 @@ class RelativePositionMultiHeadAttention(nn.Module):
@staticmethod
def _relative_position_to_absolute_position(x):
"""Converts tensor from relative to absolute indexing for local attention.
Args:
x: [B, D, length, 2 * length - 1]
Shapes:
x: :math:`[B, C, T, 2 * T - 1]`
Returns:
A Tensor of shape [B, D, length, length]
A Tensor of shape :math:`[B, C, T, T]`
"""
batch, heads, length, _ = x.size()
# Pad to shift from relative to absolute indexing.
@ -220,8 +228,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
@staticmethod
def _absolute_position_to_relative_position(x):
"""
x: [B, H, T, T]
ret: [B, H, T, 2*T-1]
Shapes:
- x: :math:`[B, C, T, T]`
- ret: :math:`[B, C, T, 2*T-1]`
"""
batch, heads, length, _ = x.size()
# pad along column
@ -239,7 +248,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
Args:
length (int): an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
a Tensor with shape :math:`[1, 1, T, T]`
"""
# L
r = torch.arange(length, dtype=torch.float32)
@ -362,8 +371,8 @@ class RelativePositionTransformer(nn.Module):
def forward(self, x, x_mask):
"""
Shapes:
x: [B, C, T]
x_mask: [B, 1, T]
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
"""
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
for i in range(self.num_layers):
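
A tiny illustration of the attention-mask construction used in the `forward()` above: the per-step mask is outer-multiplied with itself so padded positions can neither attend nor be attended to.

```python
import torch

x_mask = torch.tensor([[[1.0, 1.0, 0.0]]])               # [B, 1, T], last step is padding
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)   # [B, 1, T, T]
print(attn_mask.shape)                                   # torch.Size([1, 1, 3, 3])
print(attn_mask[0, 0])                                   # ones only where both steps are valid
```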

View File

@ -30,24 +30,31 @@ class GlowTTS(BaseTTS):
the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our
model can be easily extended to a multi-speaker setting.
Check `GlowTTSConfig` for class arguments.
Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.
Examples:
>>> from TTS.tts.configs import GlowTTSConfig
>>> from TTS.tts.models.glow_tts import GlowTTS
>>> config = GlowTTSConfig()
>>> model = GlowTTS(config)
"""
def __init__(self, config: GlowTTSConfig):
super().__init__()
chars, self.config = self.get_characters(config)
self.num_chars = len(chars)
self.decoder_output_dim = config.out_channels
self.init_multispeaker(config)
# pass all config fields to `self`
# for fewer code changes
self.config = config
for key in config:
setattr(self, key, config[key])
chars, self.config = self.get_characters(config)
self.num_chars = len(chars)
self.decoder_output_dim = config.out_channels
self.init_multispeaker(config)
# if is a multispeaker and c_in_channels is 0, set to 256
self.c_in_channels = 0
if self.num_speakers > 1:
@ -91,7 +98,7 @@ class GlowTTS(BaseTTS):
@staticmethod
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
# compute final values with the computed alignment
""" Compute and format the mode outputs with the given alignment map"""
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
@ -107,11 +114,11 @@ class GlowTTS(BaseTTS):
): # pylint: disable=dangerous-default-value
"""
Shapes:
x: [B, T]
x_lenghts: B
y: [B, T, C]
y_lengths: B
g: [B, C] or B
- x: :math:`[B, T]`
- x_lengths: :math:`[B]`
- y: :math:`[B, T, C]`
- y_lengths: :math:`[B]`
- g: :math:`[B, C]` or :math:`[B]`
"""
y = y.transpose(1, 2)
y_max_length = y.size(2)
@ -161,12 +168,13 @@ class GlowTTS(BaseTTS):
"""
It's similar to the teacher forcing in Tacotron.
It was proposed in: https://arxiv.org/abs/2104.05557
Shapes:
x: [B, T]
x_lenghts: B
y: [B, T, C]
y_lengths: B
g: [B, C] or B
- x: :math:`[B, T]`
- x_lengths: :math:`[B]`
- y: :math:`[B, T, C]`
- y_lengths: :math:`[B]`
- g: :math:`[B, C]` or :math:`[B]`
"""
y = y.transpose(1, 2)
y_max_length = y.size(2)
@ -221,9 +229,9 @@ class GlowTTS(BaseTTS):
): # pylint: disable=dangerous-default-value
"""
Shapes:
y: [B, T, C]
y_lengths: B
g: [B, C] or B
- y: :math:`[B, T, C]`
- y_lengths: :math:`[B]`
- g: :math:`[B, C]` or :math:`[B]`
"""
y = y.transpose(1, 2)
y_max_length = y.size(2)
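
A hedged sketch of a training-time forward pass using the shapes listed above. The positional arguments follow the docstring, but the exact signature (e.g. extra conditioning inputs) is an assumption and should be checked against `forward()`.

```python
import torch
from TTS.tts.configs import GlowTTSConfig
from TTS.tts.models.glow_tts import GlowTTS

config = GlowTTSConfig()
model = GlowTTS(config)

x = torch.randint(0, model.num_chars, (2, 30))   # [B, T] token ids
x_lengths = torch.tensor([30, 25])               # [B]
y = torch.randn(2, 120, config.out_channels)     # [B, T', C] mel frames
y_lengths = torch.tensor([120, 100])             # [B]
outputs = model(x, x_lengths, y, y_lengths)
```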

View File

@ -54,7 +54,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
Tensor: spectrogram frames.
Shapes:
x: [B x T] or [B x 1 x T]
x: :math:`[B, T]` or :math:`[B, 1, T]`
"""
if x.ndim == 2:
x = x.unsqueeze(1)

View File

@ -22,6 +22,9 @@ class GAN(BaseVocoder):
"""Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
It also helps mixing and matching different generator and discriminator networks easily.
To implement a new GAN model, you just need to define the generator and the discriminator networks; the rest
is handled by the `GAN` class.
Args:
config (Coqpit): Model configuration.
@ -39,12 +42,41 @@ class GAN(BaseVocoder):
self.y_hat_g = None # the last generator prediction to be passed onto the discriminator
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Run the generator's forward pass.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: output of the GAN generator network.
"""
return self.model_g.forward(x)
def inference(self, x: torch.Tensor) -> torch.Tensor:
"""Run the generator's inference pass.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: output of the GAN generator network.
"""
return self.model_g.inference(x)
def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator for
network on the current pass.
Args:
batch (Dict): Batch of samples returned by the dataloader.
criterion (Dict): Criterion used to compute the losses.
optimizer_idx (int): ID of the optimizer in use on the current pass.
Raises:
ValueError: `optimizer_idx` is an unexpected value.
Returns:
Tuple[Dict, Dict]: model outputs and the computed loss values.
"""
outputs = None
loss_dict = None
@ -145,7 +177,18 @@ class GAN(BaseVocoder):
return outputs, loss_dict
@staticmethod
def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
"""Logging shared by the training and evaluation.
Args:
name (str): Name of the run, either `train` or `eval`.
ap (AudioProcessor): Audio processor used in training.
batch (Dict): Batch used in the last train/eval step.
outputs (Dict): Model outputs from the last train/eval step.
Returns:
Tuple[Dict, Dict]: log figures and audio samples.
"""
y_hat = outputs[0]["model_outputs"]
y = batch["waveform"]
figures = plot_results(y_hat, y, ap, name)
@ -154,13 +197,16 @@ class GAN(BaseVocoder):
return figures, audios
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for training."""
return self._log("train", ap, batch, outputs)
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Call `train_step()` with `no_grad()`"""
return self.train_step(batch, criterion, optimizer_idx)
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for evaluation."""
return self._log("eval", ap, batch, outputs)
def load_checkpoint(
@ -169,6 +215,13 @@ class GAN(BaseVocoder):
checkpoint_path: str,
eval: bool = False, # pylint: disable=unused-argument, redefined-builtin
) -> None:
"""Load a GAN checkpoint and initialize model parameters.
Args:
config (Coqpit): Model config.
checkpoint_path (str): Checkpoint file path.
eval (bool, optional): If True, load the model for inference. Defaults to False.
"""
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
# band-aid for older than v0.0.15 GAN models
if "model_disc" in state:
@ -181,9 +234,21 @@ class GAN(BaseVocoder):
self.model_g.remove_weight_norm()
def on_train_step_start(self, trainer) -> None:
"""Enable the discriminator training based on `steps_to_start_discriminator`
Args:
trainer (Trainer): Trainer object.
"""
self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator
def get_optimizer(self):
def get_optimizer(self) -> List:
"""Initiate and return the GAN optimizers based on the config parameters.
It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
Returns:
List: optimizers.
"""
optimizer1 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
)
@ -192,16 +257,37 @@ class GAN(BaseVocoder):
)
return [optimizer1, optimizer2]
def get_lr(self):
def get_lr(self) -> List:
"""Set the initial learning rates for each optimizer.
Returns:
List: learning rates for each optimizer.
"""
return [self.config.lr_gen, self.config.lr_disc]
def get_scheduler(self, optimizer):
def get_scheduler(self, optimizer) -> List:
"""Set the schedulers for each optimizer.
Args:
optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
Returns:
List: Schedulers, one for each optimizer.
"""
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler1, scheduler2]
@staticmethod
def format_batch(batch):
def format_batch(batch: List) -> Dict:
"""Format the batch for training.
Args:
batch (List): Batch out of the dataloader.
Returns:
Dict: formatted model inputs.
"""
if isinstance(batch[0], list):
x_G, y_G = batch[0]
x_D, y_D = batch[1]
@ -218,6 +304,19 @@ class GAN(BaseVocoder):
verbose: bool,
num_gpus: int,
):
"""Initiate and return the GAN dataloader.
Args:
config (Coqpit): Model config.
ap (AudioProcessor): Audio processor.
is_eval (bool): Return the evaluation dataloader if True.
data_items (List): Data samples.
verbose (bool): Log information if true.
num_gpus (int): Number of GPUs in use.
Returns:
DataLoader: Torch dataloader.
"""
dataset = GANDataset(
ap=ap,
items=data_items,
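
A hedged sketch of the list convention documented above: index 0 always refers to the generator and index 1 to the discriminator, for optimizers, learning rates and schedulers alike. The `HifiganConfig` class and import paths are assumptions here.

```python
from TTS.vocoder.configs import HifiganConfig   # any GAN vocoder config; an assumption
from TTS.vocoder.models.gan import GAN

config = HifiganConfig()
model = GAN(config)

optimizers = model.get_optimizer()              # [generator_opt, discriminator_opt]
learning_rates = model.get_lr()                 # [config.lr_gen, config.lr_disc]
schedulers = model.get_scheduler(optimizers)    # one scheduler per optimizer
```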

View File

@ -34,7 +34,7 @@ class PQMF(tf.keras.layers.Layer):
def analysis(self, x):
"""
x : B x 1 x T
x : :math:`[B, 1, T]`
"""
x = tf.transpose(x, perm=[0, 2, 1])
x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0)

View File

@ -92,7 +92,7 @@ class MelganGenerator(tf.keras.models.Model):
@tf.function(experimental_relax_shapes=True)
def call(self, c, training=False):
"""
c : B x C x T
c : :math:`[B, C, T]`
"""
if training:
raise NotImplementedError()

View File

@ -113,7 +113,7 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
"""
Sample from discretized mixture of logistic distributions
Args:
y (Tensor): B x C x T
y (Tensor): :math:`[B, C, T]`
log_scale_min (float): Log scale minimum value
Returns:
Tensor: sample in range of [-1, 1].
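
A hedged usage sketch for the sampling helper documented above: with 10 logistic mixtures the input carries 3 parameters per mixture, hence 30 channels. The import path and the mixture count are assumptions.

```python
import torch
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic  # path is an assumption

y = torch.randn(1, 30, 100)                      # [B, C, T] with C = 3 * num_mixtures
wav = sample_from_discretized_mix_logistic(y)    # samples clipped to [-1, 1]
print(wav.shape, wav.min().item(), wav.max().item())
```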

View File

@ -1,25 +0,0 @@
# AudioProcessor
`TTS.utils.audio.AudioProcessor` is the core class for all the audio processing routines. It provides an API for
- Feature extraction.
- Sound normalization.
- Reading and writing audio files.
- Sampling audio signals.
- Normalizing and denormalizing audio signals.
- Griffin-Lim vocoder.
The `AudioProcessor` needs to be initialized with `TTS.config.shared_configs.BaseAudioConfig`. Any model config
also must inherit or initiate `BaseAudioConfig`.
## AudioProcessor
```{eval-rst}
.. autoclass:: TTS.utils.audio.AudioProcessor
:members:
```
## BaseAudioConfig
```{eval-rst}
.. autoclass:: TTS.config.shared_configs.BaseAudioConfig
:members:
```

View File

@ -50,6 +50,43 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'TODO/*']
source_suffix = [".rst", ".md"]
# extensions
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.autosectionlabel',
'myst_parser',
"sphinx_copybutton",
"sphinx_inline_tabs",
]
# 'sphinxcontrib.katex',
# 'sphinx.ext.autosectionlabel',
# autosectionlabel throws warnings if section names are duplicated.
# The following tells autosectionlabel to not throw a warning for
# duplicated section names that are in different documents.
autosectionlabel_prefix_document = True
language = None
autodoc_inherit_docstrings = False
# Disable displaying type annotations, these can be very verbose
autodoc_typehints = 'none'
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
napoleon_custom_sections = [('Shapes', 'shape')]
# -- Options for HTML output -------------------------------------------------
@ -80,23 +117,3 @@ html_sidebars = {
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# using markdown
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.autosectionlabel',
'myst_parser',
"sphinx_copybutton",
"sphinx_inline_tabs",
]
# 'sphinxcontrib.katex',
# 'sphinx.ext.autosectionlabel',
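
For reference, a minimal docstring sketch (hypothetical function) showing the custom `Shapes` section that the `napoleon_custom_sections = [('Shapes', 'shape')]` setting above registers, matching how it is used throughout the model docstrings:

```python
def forward(self, x, x_mask):
    """Run a single forward pass.

    Shapes:
        - x: :math:`[B, C, T]`
        - x_mask: :math:`[B, 1, T]`
    """
```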

View File

@ -1,4 +1,4 @@
# Converting Torch Tacotron to TF 2
# Converting Torch to TF 2
Currently, 🐸TTS supports the vanilla Tacotron2 and MelGAN models in TF 2. It does not support advanced attention methods and other small tricks used by the Torch models. You can convert any Torch model trained after v0.0.2.

View File

@ -1,25 +0,0 @@
# Datasets
## TTS Dataset
```{eval-rst}
.. autoclass:: TTS.tts.datasets.TTSDataset
:members:
```
## Vocoder Dataset
```{eval-rst}
.. autoclass:: TTS.vocoder.datasets.gan_dataset.GANDataset
:members:
```
```{eval-rst}
.. autoclass:: TTS.vocoder.datasets.wavegrad_dataset.WaveGradDataset
:members:
```
```{eval-rst}
.. autoclass:: TTS.vocoder.datasets.wavernn_dataset.WaveRNNDataset
:members:
```

View File

@ -105,7 +105,7 @@ The best approach is to pick a set of promising models and run a Mean-Opinion-Sc
- Check the 4th step under "How can I check model performance?"
## How can I test a trained model?
- The best way is to use `tts` or `tts-server` commands. For details check {ref}`here <Synthesizing Speech>`.
- The best way is to use `tts` or `tts-server` commands. For details check {ref}`here <synthesizing_speech>`.
- If you need to code your own ```TTS.utils.synthesizer.Synthesizer``` class.
## My Tacotron model does not stop - I see "Decoder stopped with 'max_decoder_steps" - Stopnet does not work.

View File

@ -36,7 +36,7 @@
There is also the `callback` interface by which you can manipulate both the model and the `Trainer` states. Callbacks give you
the infinite flexibility to add custom behaviours for your model and training routines.
For more details, see {ref}`BaseTTS <Base TTS Model>` and `TTS/utils/callbacks.py`.
For more details, see {ref}`BaseTTS <Base TTS Model>` and :obj:`TTS.utils.callbacks`.
6. Optionally, define `MyModelArgs`.

View File

@ -2,7 +2,6 @@
```{include} ../../README.md
:relative-images:
```
----
# Documentation Content
@ -27,14 +26,28 @@
formatting_your_dataset
what_makes_a_good_dataset
tts_datasets
converting_torch_to_tf
.. toctree::
:maxdepth: 2
:caption: Main Classes
trainer_api
audio_processor
model_api
configuration
dataset
```
main_classes/trainer_api
main_classes/audio_processor
main_classes/model_api
main_classes/dataset
main_classes/gan
.. toctree::
:maxdepth: 2
:caption: `tts` Models
models/glow_tts.md
.. toctree::
:maxdepth: 2
:caption: `vocoder` Models
main_classes/gan
```

View File

@ -1,4 +1,4 @@
# AudioProcessor
# AudioProcessor API
`TTS.utils.audio.AudioProcessor` is the core class for all the audio processing routines. It provides an API for

View File

@ -19,6 +19,6 @@ Model API provides you a set of functions that easily make your model compatible
## Base `vocoder` Model
```{eval-rst}
.. autoclass:: TTS.tts.models.base_vocoder.BaseVocoder`
.. autoclass:: TTS.vocoder.models.base_vocoder.BaseVocoder
:members:
```

View File

@ -1,24 +0,0 @@
# Model API
Model API provides you a set of functions that easily make your model compatible with the `Trainer`,
`Synthesizer` and `ModelZoo`.
## Base TTS Model
```{eval-rst}
.. autoclass:: TTS.model.BaseModel
:members:
```
## Base `tts` Model
```{eval-rst}
.. autoclass:: TTS.tts.models.base_tts.BaseTTS
:members:
```
## Base `vocoder` Model
```{eval-rst}
.. autoclass:: TTS.tts.models.base_vocoder.BaseVocoder`
:members:
```

View File

@ -1,17 +0,0 @@
# Trainer API
The {class}`TTS.trainer.Trainer` provides a lightweight, extensible, and feature-complete training run-time. We optimized it for 🐸 but
can also be used for any DL training in different domains. It supports distributed multi-gpu, mixed-precision (apex or torch.amp) training.
## Trainer
```{eval-rst}
.. autoclass:: TTS.trainer.Trainer
:members:
```
## TrainingArgs
```{eval-rst}
.. autoclass:: TTS.trainer.TrainingArgs
:members:
```