Tortoise TTS inference (#2547)

* initial commit

* Tortoise inference

* revert path change

* style fix

* remove accidental remove

* style fixes

* style fixes

* removed unwanted assets and deps

* remove changes

* remove cvvp

* style fix black

* added tortoise config and updated config and args, refactoring the code

* added tortoise to api

* Pull mel_norm from url

* Use TTS cleaners

* Let download model files

* add ability to pass tortoise presets through coqui api

* fix tests

* fix style and tests

* fix tts commandline for tortoise

* Add config.json to tortoise

* Use kwargs

* Use regular model api for loading tortoise

* Add load from dir to synthesizer

* Fix Tortoise floats

* Use model_dir when there are multiple urls

* Use `synthesize` when exists

* lint fixes and resolve preset bug

* resolve a download bug and update model link

* fix json

* do tortoise inference from voice dir

* fix

* fix test

* fix speaker id and remove assets

* update inference_tests.yml

* replace inference_test.yml

* fix extra dir as None

* fix tests

* remove space

* Reformat docstring

* Add docs

* Update docs

* lint fixes

---------

Co-authored-by: Eren Gölge <egolge@coqui.ai>
Co-authored-by: Eren Gölge <erogol@hotmail.com>
manmay nakhashi 2023-05-16 04:28:21 +05:30 committed by GitHub
parent 0b6b957e76
commit a3d5801c44
31 changed files with 8298 additions and 35 deletions


@ -52,4 +52,4 @@ jobs:
- name: Unit tests
run: make inference_tests
env:
COQUI_STUDIO_TOKEN: ${{ secrets.COQUI_STUDIO_TOKEN }}


@ -220,6 +220,26 @@
"license": "apache 2.0",
"contact": "adamfroghyar@gmail.com"
}
},
"multi-dataset":{
"tortoise-v2":{
"description": "Tortoise tts model https://github.com/neonbjb/tortoise-tts",
"github_rls_url": ["https://coqui.gateway.scarf.sh/v0.14.1_models/autoregressive.pth",
"https://coqui.gateway.scarf.sh/v0.14.1_models/clvp2.pth",
"https://coqui.gateway.scarf.sh/v0.14.1_models/cvvp.pth",
"https://coqui.gateway.scarf.sh/v0.14.1_models/diffusion_decoder.pth",
"https://coqui.gateway.scarf.sh/v0.14.1_models/rlg_auto.pth",
"https://coqui.gateway.scarf.sh/v0.14.1_models/rlg_diffuser.pth",
"https://coqui.gateway.scarf.sh/v0.14.1_models/vocoder.pth",
"https://coqui.gateway.scarf.sh/v0.14.1_models/mel_norms.pth",
"https://coqui.gateway.scarf.sh/v0.14.1_models/config.json"
],
"commit": "c1875f6",
"default_vocoder": null,
"author": "@neonbjb - James Betker, @manmay-nakhashi Manmay Nakhashi",
"license": "apache 2.0"
}
},
"jenny": {
"jenny":{


@ -342,10 +342,14 @@ class TTS:
def download_model_by_name(self, model_name: str):
model_path, config_path, model_item = self.manager.download_model(model_name)
if isinstance(model_item["github_rls_url"], list):
# return model directory if there are multiple files
# we assume that the model knows how to load itself
return None, None, None, None, model_path
if model_item.get("default_vocoder") is None:
return model_path, config_path, None, None
return model_path, config_path, None, None, None
vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
return model_path, config_path, vocoder_path, vocoder_config_path
return model_path, config_path, vocoder_path, vocoder_config_path, None
def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
"""Load one of the voice conversion models by name.
@ -355,7 +359,7 @@ class TTS:
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.model_name = model_name
model_path, config_path, _, _ = self.download_model_by_name(model_name)
model_path, config_path, _, _, _ = self.download_model_by_name(model_name)
self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu)
def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
@ -374,7 +378,9 @@ class TTS:
if "coqui_studio" in model_name:
self.csapi = CS_API()
else:
model_path, config_path, vocoder_path, vocoder_config_path = self.download_model_by_name(model_name)
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
model_name
)
# init synthesizer
# None values are fetched from the model
@ -387,6 +393,7 @@ class TTS:
vocoder_config=vocoder_config_path,
encoder_checkpoint=None,
encoder_config=None,
model_dir=model_dir,
use_cuda=gpu,
)
@ -422,6 +429,7 @@ class TTS:
speaker_wav: str = None,
emotion: str = None,
speed: float = None,
**kwargs,
) -> None:
"""Check if the arguments are valid for the model."""
if not self.is_coqui_studio:
@ -430,7 +438,7 @@ class TTS:
raise ValueError("Model is multi-speaker but no `speaker` is provided.")
if self.is_multi_lingual and language is None:
raise ValueError("Model is multi-lingual but no `language` is provided.")
if not self.is_multi_speaker and speaker is not None:
if not self.is_multi_speaker and speaker is not None and "voice_dir" not in kwargs:
raise ValueError("Model is not multi-speaker but `speaker` is provided.")
if not self.is_multi_lingual and language is not None:
raise ValueError("Model is not multi-lingual but `language` is provided.")
@ -499,6 +507,7 @@ class TTS:
speaker_wav: str = None,
emotion: str = None,
speed: float = None,
**kwargs,
):
"""Convert text to speech.
@ -520,12 +529,13 @@ class TTS:
Speed factor to use for 🐸Coqui Studio models, between 0 and 2.0. If None, Studio models use 1.0.
Defaults to None.
"""
self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, speed=speed)
self._check_arguments(
speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, speed=speed, **kwargs
)
if self.csapi is not None:
return self.tts_coqui_studio(
text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed
)
wav = self.synthesizer.tts(
text=text,
speaker_name=speaker,
@ -535,6 +545,7 @@ class TTS:
style_wav=None,
style_text=None,
reference_speaker_name=None,
**kwargs,
)
return wav
@ -547,6 +558,7 @@ class TTS:
emotion: str = "Neutral",
speed: float = 1.0,
file_path: str = "output.wav",
**kwargs,
):
"""Convert text to speech.
@ -569,13 +581,13 @@ class TTS:
file_path (str, optional):
Output file path. Defaults to "output.wav".
"""
self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav)
self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
if self.csapi is not None:
return self.tts_coqui_studio(
text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed, file_path=file_path
)
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav)
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
self.synthesizer.save_wav(wav=wav, path=file_path)
return file_path
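A minimal usage sketch of the new kwargs pass-through (not part of the diff; the model key, voice-directory layout, speaker name, and inference overrides below are assumptions based on the .models.json entry and the --voice_dir handling added in this PR):

from TTS.api import TTS

# Hypothetical example: extra kwargs such as voice_dir and Tortoise inference
# settings travel through tts_to_file -> tts -> synthesizer.tts unchanged.
api = TTS("tts_models/en/multi-dataset/tortoise-v2")
api.tts_to_file(
    text="Hello from Tortoise.",
    file_path="output.wav",
    voice_dir="path/to/tortoise/voices/",  # folder of <speaker>/<reference clips>
    speaker="lj",                          # sub-folder name inside voice_dir (assumed)
    num_autoregressive_samples=1,          # preset-style overrides also pass through
    diffusion_iterations=10,
)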


@ -274,6 +274,13 @@ If you don't specify any models, then it uses LJSpeech based English model.
help="Target audio file to convert in the voice of the source_wav",
)
parser.add_argument(
"--voice_dir",
type=str,
default=None,
help="Voice dir for tortoise model",
)
args = parser.parse_args()
# print the description if either text or list_models is not set
@ -306,6 +313,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
encoder_config_path = None
vc_path = None
vc_config_path = None
model_dir = None
# CASE1 #list : list pre-trained TTS models
if args.list_models:
@ -335,7 +343,6 @@ If you don't specify any models, then it uses LJSpeech based English model.
# CASE4: load pre-trained model paths
if args.model_name is not None and not args.model_path:
model_path, config_path, model_item = manager.download_model(args.model_name)
# tts model
if model_item["model_type"] == "tts_models":
tts_path = model_path
@ -348,6 +355,13 @@ If you don't specify any models, then it uses LJSpeech based English model.
vc_path = model_path
vc_config_path = config_path
# tts model with multiple files to be loaded from the directory path
if isinstance(model_item["github_rls_url"], list):
model_dir = model_path
tts_path = None
tts_config_path = None
args.vocoder_name = None
# load vocoder
if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
@ -379,6 +393,8 @@ If you don't specify any models, then it uses LJSpeech based English model.
encoder_config_path,
vc_path,
vc_config_path,
model_dir,
args.voice_dir,
args.use_cuda,
)
@ -427,6 +443,8 @@ If you don't specify any models, then it uses LJSpeech based English model.
source_wav=args.source_wav,
target_wav=args.target_wav,
)
elif model_dir is not None:
wav = synthesizer.tts(args.text, speaker_name=args.speaker_idx)
# save the results
print(" > Saving output to {}".format(args.out_path))


@ -0,0 +1,87 @@
from dataclasses import dataclass, field
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.tortoise import TortoiseArgs, TortoiseAudioConfig
@dataclass
class TortoiseConfig(BaseTTSConfig):
"""Defines parameters for Tortoise TTS model.
Args:
model (str):
Model name. Do not change unless you know what you are doing.
model_args (TortoiseArgs):
Model architecture arguments. Defaults to `TortoiseArgs()`.
audio (TortoiseAudioConfig):
Audio processing configuration. Defaults to `TortoiseAudioConfig()`.
model_dir (str):
Path to the folder that has all the Tortoise models. Defaults to None.
temperature (float):
Temperature for the autoregressive model inference. Larger values make predictions more creative at the cost of stability. Defaults to `0.2`.
length_penalty (float):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length,
which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative),
length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
repetition_penalty (float):
The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`.
top_p (float):
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
Defaults to `0.8`.
cond_free_k (float):
Knob that determines how to balance the conditioning-free signal with the conditioning-present signal. [0,inf].
As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
Formula is: output = cond_present_output * (cond_free_k + 1) - cond_absent_output * cond_free_k, so the default
cond_free_k = 2.0 gives output = 3 * cond_present_output - 2 * cond_absent_output. Defaults to `2.0`.
diffusion_temperature (float):
Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
are the "mean" prediction of the diffusion network and will sound bland and smeared.
Defaults to `1.0`.
num_autoregressive_samples (int):
Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great".
Defaults to `16`.
diffusion_iterations (int):
Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
however. Defaults to `30`.
sampler (str):
Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`.
Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
Example:
>>> from TTS.tts.configs.tortoise_config import TortoiseConfig
>>> config = TortoiseConfig()
"""
model: str = "tortoise"
# model specific params
model_args: TortoiseArgs = field(default_factory=TortoiseArgs)
audio: TortoiseAudioConfig = field(default_factory=TortoiseAudioConfig)
model_dir: str = None
# settings
temperature: float = 0.2
length_penalty: float = 1.0
repetition_penalty: float = 2.0
top_p: float = 0.8
cond_free_k: float = 2.0
diffusion_temperature: float = 1.0
# inference params
num_autoregressive_samples: int = 16
diffusion_iterations: int = 30
sampler: str = "ddim"
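A short sketch of how these settings might be used with the model class added elsewhere in this PR (the Tortoise.init_from_config / load_checkpoint / synthesize API, the checkpoint directory, and the voice folder are assumptions drawn from the rest of the change set):

from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.models.tortoise import Tortoise

# Faster-than-default settings, roughly an "ultra fast" style preset.
config = TortoiseConfig(num_autoregressive_samples=1, diffusion_iterations=10)
model = Tortoise.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="path/to/tortoise/models/", eval=True)

# Clone a voice from reference clips stored under path/to/voices/lj/*.wav (hypothetical layout).
output = model.synthesize(
    "Hello from Tortoise.",
    config,
    speaker_id="lj",
    voice_dirs=["path/to/voices/"],
)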


@ -0,0 +1,433 @@
import functools
import math
import os
import fsspec
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import LogitsWarper
from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
groups = 32
if channels <= 16:
groups = 8
elif channels <= 64:
groups = 16
while channels % groups != 0:
groups = int(groups / 2)
assert groups > 2
return GroupNorm32(groups, channels)
class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, mask=None, rel_pos=None):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
if rel_pos is not None:
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
bs * self.n_heads, weight.shape[-2], weight.shape[-1]
)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
if mask is not None:
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
weight = weight * mask
a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
do_checkpoint=True,
relative_pos_embeddings=False,
):
super().__init__()
self.channels = channels
self.do_checkpoint = do_checkpoint
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = normalization(channels)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
if relative_pos_embeddings:
self.relative_pos_embeddings = RelativePositionBias(
scale=(channels // self.num_heads) ** 0.5,
causal=False,
heads=num_heads,
num_buckets=32,
max_distance=64,
)
else:
self.relative_pos_embeddings = None
def forward(self, x, mask=None):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv, mask, self.relative_pos_embeddings)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
"""
def __init__(self, channels, use_conv, out_channels=None, factor=4):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.factor = factor
if use_conv:
ksize = 5
pad = 2
self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)
def forward(self, x):
assert x.shape[1] == self.channels
x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
"""
def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
stride = factor
if use_conv:
self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=stride, padding=pad)
else:
assert self.channels == self.out_channels
self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(nn.Module):
def __init__(
self,
channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
up=False,
down=False,
kernel_size=3,
):
super().__init__()
self.channels = channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_scale_shift_norm = use_scale_shift_norm
padding = 1 if kernel_size == 3 else 2
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False)
self.x_upd = Upsample(channels, False)
elif down:
self.h_upd = Downsample(channels, False)
self.x_upd = Downsample(channels, False)
else:
self.h_upd = self.x_upd = nn.Identity()
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding)
else:
self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
def forward(self, x):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
h = self.out_layers(h)
return self.skip_connection(x) + h
class AudioMiniEncoder(nn.Module):
def __init__(
self,
spec_dim,
embedding_dim,
base_channels=128,
depth=2,
resnet_blocks=2,
attn_blocks=4,
num_attn_heads=4,
dropout=0,
downsample_factor=2,
kernel_size=3,
):
super().__init__()
self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
ch = base_channels
res = []
for l in range(depth):
for r in range(resnet_blocks):
res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
ch *= 2
self.res = nn.Sequential(*res)
self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
attn = []
for a in range(attn_blocks):
attn.append(
AttentionBlock(
embedding_dim,
num_attn_heads,
)
)
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
def forward(self, x):
h = self.init(x)
h = self.res(h)
h = self.final(h)
h = self.attn(h)
return h[:, :, 0]
DEFAULT_MEL_NORM_FILE = "https://coqui.gateway.scarf.sh/v0.14.1_models/mel_norms.pth"
class TorchMelSpectrogram(nn.Module):
def __init__(
self,
filter_length=1024,
hop_length=256,
win_length=1024,
n_mel_channels=80,
mel_fmin=0,
mel_fmax=8000,
sampling_rate=22050,
normalize=False,
mel_norm_file=DEFAULT_MEL_NORM_FILE,
):
super().__init__()
# These are the default tacotron values for the MEL spectrogram.
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.n_mel_channels = n_mel_channels
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.sampling_rate = sampling_rate
self.mel_stft = torchaudio.transforms.MelSpectrogram(
n_fft=self.filter_length,
hop_length=self.hop_length,
win_length=self.win_length,
power=2,
normalized=normalize,
sample_rate=self.sampling_rate,
f_min=self.mel_fmin,
f_max=self.mel_fmax,
n_mels=self.n_mel_channels,
norm="slaney",
)
self.mel_norm_file = mel_norm_file
if self.mel_norm_file is not None:
with fsspec.open(self.mel_norm_file) as f:
self.mel_norms = torch.load(f)
else:
self.mel_norms = None
def forward(self, inp):
if (
len(inp.shape) == 3
): # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
inp = inp.squeeze(1)
assert len(inp.shape) == 2
self.mel_stft = self.mel_stft.to(inp.device)
mel = self.mel_stft(inp)
# Perform dynamic range compression
mel = torch.log(torch.clamp(mel, min=1e-5))
if self.mel_norms is not None:
self.mel_norms = self.mel_norms.to(mel.device)
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
return mel
class CheckpointedLayer(nn.Module):
"""
Wraps a module. When forward() is called, asserts that no tensor kwargs require grad (gradient checkpointing is not
applied here) and forwards all args and kwargs to the wrapped module.
"""
def __init__(self, wrap):
super().__init__()
self.wrap = wrap
def forward(self, x, *args, **kwargs):
for k, v in kwargs.items():
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
partial = functools.partial(self.wrap, **kwargs)
return partial(x, *args)
class CheckpointedXTransformerEncoder(nn.Module):
"""
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
to channels-last that XTransformer expects.
"""
def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
super().__init__()
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
self.needs_permute = needs_permute
self.exit_permute = exit_permute
if not checkpoint:
return
for i in range(len(self.transformer.attn_layers.layers)):
n, b, r = self.transformer.attn_layers.layers[i]
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
def forward(self, x, **kwargs):
if self.needs_permute:
x = x.permute(0, 2, 1)
h = self.transformer(x, **kwargs)
if self.exit_permute:
h = h.permute(0, 2, 1)
return h
class TypicalLogitsWarper(LogitsWarper):
def __init__(
self,
mass: float = 0.9,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
):
self.filter_value = filter_value
self.mass = mass
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < self.mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
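An illustrative smoke test for the TorchMelSpectrogram front-end defined above (not part of the diff; mel_norm_file=None is passed so nothing is downloaded, and the input is random audio):

import torch

mel_fn = TorchMelSpectrogram(mel_norm_file=None)  # skip fetching mel_norms.pth
wav = torch.randn(1, 22050)                        # one second of fake 22.05 kHz audio
mel = mel_fn(wav)                                  # log-compressed mel, shape (1, 80, frames)
print(mel.shape)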


@ -0,0 +1,177 @@
import os
from glob import glob
from typing import Dict, List
import librosa
import numpy as np
import torch
import torchaudio
from scipy.io.wavfile import read
from TTS.utils.audio.torch_transforms import TorchSTFT
def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)
if data.dtype == np.int32:
norm_fix = 2**31
elif data.dtype == np.int16:
norm_fix = 2**15
elif data.dtype == np.float16 or data.dtype == np.float32:
norm_fix = 1.0
else:
raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}")
return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
def check_audio(audio, audiopath: str):
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 2) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
audio.clip_(-1, 1)
def read_audio_file(audiopath: str):
if audiopath[-4:] == ".wav":
audio, lsr = load_wav_to_torch(audiopath)
elif audiopath[-4:] == ".mp3":
audio, lsr = librosa.load(audiopath, sr=None)
audio = torch.FloatTensor(audio)
else:
assert False, f"Unsupported audio format provided: {audiopath[-4:]}"
# Remove any channel data.
if len(audio.shape) > 1:
if audio.shape[0] < 5:
audio = audio[0]
else:
assert audio.shape[1] < 5
audio = audio[:, 0]
return audio, lsr
def load_required_audio(audiopath: str):
audio, lsr = read_audio_file(audiopath)
audios = [torchaudio.functional.resample(audio, lsr, sampling_rate) for sampling_rate in (22050, 24000)]
for audio in audios:
check_audio(audio, audiopath)
return [audio.unsqueeze(0) for audio in audios]
def load_audio(audiopath, sampling_rate):
audio, lsr = read_audio_file(audiopath)
if lsr != sampling_rate:
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
check_audio(audio, audiopath)
return audio.unsqueeze(0)
TACOTRON_MEL_MAX = 2.3143386840820312
TACOTRON_MEL_MIN = -11.512925148010254
def denormalize_tacotron_mel(norm_mel):
return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN
def normalize_tacotron_mel(mel):
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
def dynamic_range_compression(x, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C
def get_voices(extra_voice_dirs: List[str] = []):
dirs = extra_voice_dirs
voices: Dict[str, List[str]] = {}
for d in dirs:
subs = os.listdir(d)
for sub in subs:
subj = os.path.join(d, sub)
if os.path.isdir(subj):
voices[sub] = list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) + list(glob(f"{subj}/*.pth"))
return voices
def load_voice(voice: str, extra_voice_dirs: List[str] = []):
if voice == "random":
return None, None
voices = get_voices(extra_voice_dirs)
paths = voices[voice]
if len(paths) == 1 and paths[0].endswith(".pth"):
return None, torch.load(paths[0])
else:
conds = []
for cond_path in paths:
c = load_required_audio(cond_path)
conds.append(c)
return conds, None
def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
latents = []
clips = []
for voice in voices:
if voice == "random":
if len(voices) > 1:
print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
return None, None
clip, latent = load_voice(voice, extra_voice_dirs)
if latent is None:
assert (
len(latents) == 0
), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
clips.extend(clip)
elif clip is None:
assert (
len(clips) == 0
), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this."
latents.append(latent)
if len(latents) == 0:
return clips, None
else:
latents_0 = torch.stack([l[0] for l in latents], dim=0).mean(dim=0)
latents_1 = torch.stack([l[1] for l in latents], dim=0).mean(dim=0)
latents = (latents_0, latents_1)
return None, latents
def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"):
stft = TorchSTFT(
n_fft=1024,
hop_length=256,
win_length=1024,
use_mel=True,
n_mels=100,
sample_rate=24000,
mel_fmin=0,
mel_fmax=12000,
)
stft = stft.to(device)
mel = stft(wav)
mel = dynamic_range_compression(mel)
if do_normalization:
mel = normalize_tacotron_mel(mel)
return mel
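An illustrative use of the voice-loading helpers above (the directory layout and voice name are hypothetical: each sub-folder of a voice dir holds reference .wav/.mp3 clips or a cached .pth latent):

conds, latents = load_voices(["lj"], extra_voice_dirs=["path/to/voices/"])
# -> (list_of_conditioning_clips, None) for raw audio voices,
#    or (None, precomputed_latent_tuple) when a cached .pth latent is found
print(conds is not None, latents is not None)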


@ -0,0 +1,631 @@
# AGPL: a notification must be added stating that changes have been made to that file.
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
def _p(t):
return t and (len(t), len(t[0]), t[0][0].shape) # kv_cache debug
class ResBlock(nn.Module):
"""
Basic residual convolutional block that uses GroupNorm.
"""
def __init__(self, chan):
super().__init__()
self.net = nn.Sequential(
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan // 8, chan),
nn.ReLU(),
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
nn.GroupNorm(chan // 8, chan),
)
def forward(self, x):
return F.relu(self.net(x) + x)
class GPT2InferenceModel(GPT2PreTrainedModel):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache):
super().__init__(config)
self.transformer = gpt
self.text_pos_embedding = text_pos_emb
self.embeddings = embeddings
self.lm_head = nn.Sequential(norm, linear)
self.kv_cache = kv_cache
def store_mel_emb(self, mel_emb):
self.cached_mel_emb = mel_emb
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None) # usually None
if not self.kv_cache:
past_key_values = None
# only last token for inputs_ids if past is defined in kwargs
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
assert self.cached_mel_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Create embedding
mel_len = self.cached_mel_emb.shape[1]
if input_ids.shape[1] != 1:
text_inputs = input_ids[:, mel_len:]
text_emb = self.embeddings(text_inputs)
text_emb = text_emb + self.text_pos_embedding(text_emb)
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0] // self.cached_mel_emb.shape[0], 0)
else: # this outcome only occurs once per loop in most cases
mel_emb = self.cached_mel_emb
emb = torch.cat([mel_emb, text_emb], dim=1)
else:
emb = self.embeddings(input_ids)
emb = emb + self.text_pos_embedding.get_fixed_embedding(
attention_mask.shape[1] - mel_len, attention_mask.device
)
transformer_outputs = self.transformer(
inputs_embeds=emb,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + transformer_outputs[1:]
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
@staticmethod
def _reorder_cache(past, beam_idx):
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
class ConditioningEncoder(nn.Module):
def __init__(
self,
spec_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
do_checkpointing=False,
mean=False,
):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
self.mean = mean
def forward(self, x):
h = self.init(x)
h = self.attn(h)
if self.mean:
return h.mean(dim=2)
else:
return h[:, :, 0]
class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=0.02):
super().__init__()
self.emb = nn.Embedding(seq_len, model_dim)
# Initializing this way is standard for GPT-2
self.emb.weight.data.normal_(mean=0.0, std=init)
def forward(self, x):
sl = x.shape[1]
return self.emb(torch.arange(0, sl, device=x.device))
def get_fixed_embedding(self, ind, dev):
return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
"""
GPT-2 implemented by the HuggingFace library.
"""
from transformers import GPT2Config, GPT2Model
gpt_config = GPT2Config(
vocab_size=256, # Unused.
n_positions=max_mel_seq_len + max_text_seq_len,
n_ctx=max_mel_seq_len + max_text_seq_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing,
)
gpt = GPT2Model(gpt_config)
# Override the built in positional embeddings
del gpt.wpe # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024)
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# Built-in token embeddings are unused.
del gpt.wte
return (
gpt,
LearnedPositionEmbeddings(max_mel_seq_len, model_dim),
LearnedPositionEmbeddings(max_text_seq_len, model_dim),
None,
None,
)
class MelEncoder(nn.Module):
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
super().__init__()
self.channels = channels
self.encoder = nn.Sequential(
nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels // 16, channels // 2),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels // 8, channels),
nn.ReLU(),
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
)
self.reduction = 4
def forward(self, x):
for e in self.encoder:
x = e(x)
return x.permute(0, 2, 1)
class UnifiedVoice(nn.Module):
def __init__(
self,
layers=8,
model_dim=512,
heads=8,
max_text_tokens=120,
max_mel_tokens=250,
max_conditioning_inputs=1,
mel_length_compression=1024,
number_text_tokens=256,
start_text_token=None,
number_mel_codes=8194,
start_mel_token=8192,
stop_mel_token=8193,
train_solo_embeddings=False,
use_mel_codes_as_input=True,
checkpointing=True,
types=1,
):
"""
Args:
layers: Number of layers in transformer stack.
model_dim: Operating dimensions of the transformer
heads: Number of transformer heads. model_dim must be divisible by heads; model_dim // 64 is recommended.
max_text_tokens: Maximum number of text tokens that will be encountered by model.
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
number_text_tokens:
start_text_token:
stop_text_token:
number_mel_codes:
start_mel_token:
stop_mel_token:
train_solo_embeddings:
use_mel_codes_as_input:
checkpointing:
"""
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
self.stop_text_token = 0
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
self.layers = layers
self.heads = heads
self.max_mel_tokens = max_mel_tokens
self.max_text_tokens = max_text_tokens
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
if use_mel_codes_as_input:
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
else:
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
(
self.gpt,
self.mel_pos_embedding,
self.text_pos_embedding,
self.mel_layer_pos_embedding,
self.text_layer_pos_embedding,
) = build_hf_gpt_transformer(
layers,
model_dim,
heads,
self.max_mel_tokens + 2 + self.max_conditioning_inputs,
self.max_text_tokens + 2,
checkpointing,
)
if train_solo_embeddings:
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
else:
self.mel_solo_embedding = 0
self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
# Initialize the embeddings per the GPT-2 scheme
embeddings = [self.text_embedding]
if use_mel_codes_as_input:
embeddings.append(self.mel_embedding)
for module in embeddings:
module.weight.data.normal_(mean=0.0, std=0.02)
def post_init_gpt2_config(self, kv_cache=True):
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.model_dim,
n_layer=self.layers,
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True,
)
self.inference_model = GPT2InferenceModel(
gpt_config,
self.gpt,
self.mel_pos_embedding,
self.mel_embedding,
self.final_norm,
self.mel_head,
kv_cache=kv_cache,
)
# self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
self.gpt.wte = self.mel_embedding
# self.inference_model.save_pretrained("")
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
return inp, tar
def set_mel_padding(self, mel_input_tokens, wav_lengths):
"""
Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
preformatting to create a working TTS model.
"""
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode="trunc")
for b in range(len(mel_lengths)):
actual_end = (
mel_lengths[b] + 1
) # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
if actual_end < mel_input_tokens.shape[-1]:
mel_input_tokens[b, actual_end:] = self.stop_mel_token
return mel_input_tokens
def get_logits(
self,
speech_conditioning_inputs,
first_inputs,
first_head,
second_inputs=None,
second_head=None,
get_attns=False,
return_latent=False,
):
if second_inputs is not None:
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
enc = self.final_norm(enc)
if return_latent:
return (
enc[
:,
speech_conditioning_inputs.shape[1] : speech_conditioning_inputs.shape[1] + first_inputs.shape[1],
],
enc[:, -second_inputs.shape[1] :],
)
first_logits = enc[:, : first_inputs.shape[1]]
first_logits = first_head(first_logits)
first_logits = first_logits.permute(0, 2, 1)
if second_inputs is not None:
second_logits = enc[:, -second_inputs.shape[1] :]
second_logits = second_head(second_logits)
second_logits = second_logits.permute(0, 2, 1)
return first_logits, second_logits
else:
return first_logits
def get_conditioning(self, speech_conditioning_input):
speech_conditioning_input = (
speech_conditioning_input.unsqueeze(1)
if len(speech_conditioning_input.shape) == 3
else speech_conditioning_input
)
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
conds = conds.mean(dim=1)
return conds
def forward(
self,
speech_conditioning_latent,
text_inputs,
text_lengths,
mel_codes,
wav_lengths,
types=None,
text_first=True,
raw_mels=None,
return_attentions=False,
return_latent=False,
clip_inputs=True,
):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`).
speech_conditioning_latent: speech conditioning latent, float tensor (b,1024)
text_inputs: long tensor, (b,t)
text_lengths: long tensor, (b,)
mel_codes: long tensor, (b,m)
wav_lengths: long tensor, (b,)
raw_mels: MEL float tensor (b,80,s)
If return_attentions is specified, only logits are returned.
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
"""
# Types are expressed by expanding the text embedding space.
if types is not None:
text_inputs = text_inputs * (1 + types).unsqueeze(-1)
if clip_inputs:
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_text_len = text_lengths.max()
text_inputs = text_inputs[:, :max_text_len]
max_mel_len = wav_lengths.max() // self.mel_length_compression
mel_codes = mel_codes[:, :max_mel_len]
if raw_mels is not None:
raw_mels = raw_mels[:, :, : max_mel_len * 4]
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
conds = speech_conditioning_latent.unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token
)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
mel_codes, self.start_mel_token, self.stop_mel_token
)
if raw_mels is not None:
mel_inp = F.pad(raw_mels, (0, 8))
else:
mel_inp = mel_codes
mel_emb = self.mel_embedding(mel_inp)
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
if text_first:
text_logits, mel_logits = self.get_logits(
conds,
text_emb,
self.text_head,
mel_emb,
self.mel_head,
get_attns=return_attentions,
return_latent=return_latent,
)
if return_latent:
return mel_logits[
:, :-2
] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
else:
mel_logits, text_logits = self.get_logits(
conds,
mel_emb,
self.mel_head,
text_emb,
self.text_head,
get_attns=return_attentions,
return_latent=return_latent,
)
if return_latent:
return text_logits[
:, :-2
] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
if return_attentions:
return mel_logits
loss_text = F.cross_entropy(text_logits, text_targets.long())
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
return loss_text.mean(), loss_mel.mean(), mel_logits
def inference_speech(
self,
speech_conditioning_latent,
text_inputs,
input_tokens=None,
num_return_sequences=1,
max_generate_length=None,
typical_sampling=False,
typical_mass=0.9,
**hf_generate_kwargs,
):
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token
)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
conds = speech_conditioning_latent.unsqueeze(1)
emb = torch.cat([conds, text_emb], dim=1)
self.inference_model.store_mel_emb(emb)
fake_inputs = torch.full(
(
emb.shape[0],
conds.shape[1] + emb.shape[1],
),
fill_value=1,
dtype=torch.long,
device=text_inputs.device,
)
fake_inputs[:, -1] = self.start_mel_token
trunc_index = fake_inputs.shape[1]
if input_tokens is None:
inputs = fake_inputs
else:
assert (
num_return_sequences % input_tokens.shape[0] == 0
), "The number of return sequences must be divisible by the number of input sequences"
fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
inputs = torch.cat([fake_inputs, input_tokens], dim=1)
logits_processor = (
LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
) # TODO disable this
max_length = (
trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
)
gen = self.inference_model.generate(
inputs,
bos_token_id=self.start_mel_token,
pad_token_id=self.stop_mel_token,
eos_token_id=self.stop_mel_token,
max_length=max_length,
logits_processor=logits_processor,
num_return_sequences=num_return_sequences,
**hf_generate_kwargs,
)
return gen[:, trunc_index:]
if __name__ == "__main__":
gpt = UnifiedVoice(
model_dim=256,
heads=4,
train_solo_embeddings=True,
use_mel_codes_as_input=True,
max_conditioning_inputs=4,
)
# forward() expects a precomputed conditioning latent of shape (b, model_dim)
l = gpt(
torch.randn(2, 256),
torch.randint(high=120, size=(2, 120)),
torch.tensor([32, 120]),
torch.randint(high=8192, size=(2, 250)),
torch.tensor([250 * 256, 195 * 256]),
)


@ -0,0 +1,144 @@
import torch
import torch.nn as nn
from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, Downsample, Upsample, normalization, zero_module
class ResBlock(nn.Module):
def __init__(
self,
channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
up=False,
down=False,
kernel_size=3,
do_checkpoint=True,
):
super().__init__()
self.channels = channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_scale_shift_norm = use_scale_shift_norm
self.do_checkpoint = do_checkpoint
padding = 1 if kernel_size == 3 else 2
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding)
else:
self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
def forward(self, x):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
h = self.out_layers(h)
return self.skip_connection(x) + h
class AudioMiniEncoder(nn.Module):
def __init__(
self,
spec_dim,
embedding_dim,
base_channels=128,
depth=2,
resnet_blocks=2,
attn_blocks=4,
num_attn_heads=4,
dropout=0,
downsample_factor=2,
kernel_size=3,
):
super().__init__()
self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
ch = base_channels
res = []
self.layers = depth
for l in range(depth):
for r in range(resnet_blocks):
res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
ch *= 2
self.res = nn.Sequential(*res)
self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
attn = []
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
def forward(self, x):
h = self.init(x)
h = self.res(h)
h = self.final(h)
for blk in self.attn:
h = blk(h)
return h[:, :, 0]
class AudioMiniEncoderWithClassifierHead(nn.Module):
def __init__(self, classes, distribute_zero_label=True, **kwargs):
super().__init__()
self.enc = AudioMiniEncoder(**kwargs)
self.head = nn.Linear(self.enc.dim, classes)
self.num_classes = classes
self.distribute_zero_label = distribute_zero_label
def forward(self, x, labels=None):
h = self.enc(x)
logits = self.head(h)
if labels is None:
return logits
else:
if self.distribute_zero_label:
oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
zeros_indices = (labels == 0).unsqueeze(-1)
# Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise.
zero_extra_mass = torch.full_like(
oh_labels,
dtype=torch.float,
fill_value=0.2 / (self.num_classes - 1),
)
zero_extra_mass[:, 0] = -0.2
zero_extra_mass = zero_extra_mass * zeros_indices
oh_labels = oh_labels + zero_extra_mass
else:
oh_labels = labels
loss = nn.functional.cross_entropy(logits, oh_labels)
return loss
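A hypothetical smoke test for the classifier defined above (all dimensions and labels are illustrative):

import torch

clf = AudioMiniEncoderWithClassifierHead(classes=4, spec_dim=80, embedding_dim=128)
logits = clf(torch.randn(2, 80, 256))                             # no labels -> raw logits, shape (2, 4)
loss = clf(torch.randn(2, 80, 256), labels=torch.tensor([0, 3]))  # labels -> cross-entropy loss
print(logits.shape, loss.item())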


@ -0,0 +1,159 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from TTS.tts.layers.tortoise.arch_utils import CheckpointedXTransformerEncoder
from TTS.tts.layers.tortoise.transformer import Transformer
from TTS.tts.layers.tortoise.xtransformers import Encoder
def exists(val):
return val is not None
def masked_mean(t, mask, dim=1):
t = t.masked_fill(~mask[:, :, None], 0.0)
return t.sum(dim=1) / mask.sum(dim=1)[..., None]
class CLVP(nn.Module):
"""
CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
transcribed text.
Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
"""
def __init__(
self,
*,
dim_text=512,
dim_speech=512,
dim_latent=512,
num_text_tokens=256,
text_enc_depth=6,
text_seq_len=120,
text_heads=8,
num_speech_tokens=8192,
speech_enc_depth=6,
speech_heads=8,
speech_seq_len=250,
text_mask_percentage=0,
voice_mask_percentage=0,
wav_token_compression=1024,
use_xformers=False,
):
super().__init__()
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
if use_xformers:
self.text_transformer = CheckpointedXTransformerEncoder(
needs_permute=False,
exit_permute=False,
max_seq_len=-1,
attn_layers=Encoder(
dim=dim_text,
depth=text_enc_depth,
heads=text_heads,
ff_dropout=0.1,
ff_mult=2,
attn_dropout=0.1,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
),
)
self.speech_transformer = CheckpointedXTransformerEncoder(
needs_permute=False,
exit_permute=False,
max_seq_len=-1,
attn_layers=Encoder(
dim=dim_speech,
depth=speech_enc_depth,
heads=speech_heads,
ff_dropout=0.1,
ff_mult=2,
attn_dropout=0.1,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
),
)
else:
self.text_transformer = Transformer(
causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, heads=text_heads
)
self.speech_transformer = Transformer(
causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads
)
self.temperature = nn.Parameter(torch.tensor(1.0))
self.text_mask_percentage = text_mask_percentage
self.voice_mask_percentage = voice_mask_percentage
self.wav_token_compression = wav_token_compression
self.xformers = use_xformers
if not use_xformers:
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
def forward(self, text, speech_tokens, return_loss=False):
b, device = text.shape[0], text.device
if self.training:
text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
else:
text_mask = torch.ones_like(text.float()).bool()
voice_mask = torch.ones_like(speech_tokens.float()).bool()
text_emb = self.text_emb(text)
speech_emb = self.speech_emb(speech_tokens)
if not self.xformers:
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
enc_text = self.text_transformer(text_emb, mask=text_mask)
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
text_latents = masked_mean(enc_text, text_mask, dim=1)
speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
text_latents = self.to_text_latent(text_latents)
speech_latents = self.to_speech_latent(speech_latents)
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
temp = self.temperature.exp()
if not return_loss:
sim = einsum("n d, n d -> n", text_latents, speech_latents) * temp
return sim
sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp
labels = torch.arange(b, device=device)
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
return loss
if __name__ == "__main__":
clip = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2)
# forward() takes (text, speech_tokens, return_loss=False)
clip(
torch.randint(0, 256, (2, 120)),
torch.randint(0, 8192, (2, 250)),
return_loss=True,
)
nonloss = clip(
torch.randint(0, 256, (2, 120)),
torch.randint(0, 8192, (2, 250)),
return_loss=False,
)
print(nonloss.shape)

File diff suppressed because it is too large.


@ -0,0 +1,415 @@
import math
import random
from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, normalization
def is_latent(t):
return t.dtype == torch.float
def is_sequence(t):
return t.dtype == torch.long
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
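
A minimal shape check of what this produces, assuming the module lands as TTS.tts.layers.tortoise.diffusion_decoder (the import path used by tortoise.py further below):

import torch

from TTS.tts.layers.tortoise.diffusion_decoder import timestep_embedding

# one timestep per batch element -> one sinusoidal vector per element
ts = torch.tensor([0, 250, 999])
emb = timestep_embedding(ts, dim=512)
print(emb.shape)  # torch.Size([3, 512]); DiffusionTts feeds these through self.time_embed
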
class TimestepBlock(nn.Module):
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
def forward(self, x, emb):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
else:
x = layer(x)
return x
class ResBlock(TimestepBlock):
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
dims=2,
kernel_size=3,
efficient_config=True,
use_scale_shift_norm=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_scale_shift_norm = use_scale_shift_norm
padding = {1: 0, 3: 1, 5: 2}[kernel_size]
eff_kernel = 1 if efficient_config else 3
eff_padding = 0 if efficient_config else 1
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
)
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
def forward(self, x, emb):
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class DiffusionLayer(TimestepBlock):
def __init__(self, model_channels, dropout, num_heads):
super().__init__()
self.resblk = ResBlock(
model_channels,
model_channels,
dropout,
model_channels,
dims=1,
use_scale_shift_norm=True,
)
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
def forward(self, x, time_emb):
y = self.resblk(x, time_emb)
return self.attn(y)
class DiffusionTts(nn.Module):
def __init__(
self,
model_channels=512,
num_layers=8,
in_channels=100,
in_latent_channels=512,
in_tokens=8193,
out_channels=200, # mean and variance
dropout=0,
use_fp16=False,
num_heads=16,
# Parameters for regularization.
layer_drop=0.1,
unconditioned_percentage=0.1, # This implements a mechanism similar to what is used in classifier-free training.
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.dropout = dropout
self.num_heads = num_heads
self.unconditioned_percentage = unconditioned_percentage
self.enable_fp16 = use_fp16
self.layer_drop = layer_drop
self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
self.time_embed = nn.Sequential(
nn.Linear(model_channels, model_channels),
nn.SiLU(),
nn.Linear(model_channels, model_channels),
)
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
# transformer network.
self.code_embedding = nn.Embedding(in_tokens, model_channels)
self.code_converter = nn.Sequential(
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
)
self.code_norm = normalization(model_channels)
self.latent_conditioner = nn.Sequential(
nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
)
self.contextual_embedder = nn.Sequential(
nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
nn.Conv1d(model_channels, model_channels * 2, 3, padding=1, stride=2),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
do_checkpoint=False,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
do_checkpoint=False,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
do_checkpoint=False,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
do_checkpoint=False,
),
AttentionBlock(
model_channels * 2,
num_heads,
relative_pos_embeddings=True,
do_checkpoint=False,
),
)
self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
self.conditioning_timestep_integrator = TimestepEmbedSequential(
DiffusionLayer(model_channels, dropout, num_heads),
DiffusionLayer(model_channels, dropout, num_heads),
DiffusionLayer(model_channels, dropout, num_heads),
)
self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1)
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.layers = nn.ModuleList(
[DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)]
+ [
ResBlock(
model_channels,
model_channels,
dropout,
dims=1,
use_scale_shift_norm=True,
)
for _ in range(3)
]
)
self.out = nn.Sequential(
normalization(model_channels),
nn.SiLU(),
nn.Conv1d(model_channels, out_channels, 3, padding=1),
)
def get_grad_norm_parameter_groups(self):
groups = {
"minicoder": list(self.contextual_embedder.parameters()),
"layers": list(self.layers.parameters()),
"code_converters": list(self.code_embedding.parameters())
+ list(self.code_converter.parameters())
            + list(self.latent_conditioner.parameters()),
"timestep_integrator": list(self.conditioning_timestep_integrator.parameters())
+ list(self.integrating_conv.parameters()),
"time_embed": list(self.time_embed.parameters()),
}
return groups
def get_conditioning(self, conditioning_input):
speech_conditioning_input = (
conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
)
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
conds = torch.cat(conds, dim=-1)
conds = conds.mean(dim=-1)
return conds
def timestep_independent(
self,
aligned_conditioning,
conditioning_latent,
expected_seq_len,
return_code_pred,
):
# Shuffle aligned_latent to BxCxS format
if is_latent(aligned_conditioning):
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1)
if is_latent(aligned_conditioning):
code_emb = self.latent_conditioner(aligned_conditioning)
else:
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
code_emb = self.code_converter(code_emb)
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = (
torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage
)
code_emb = torch.where(
unconditioned_batches,
self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
code_emb,
)
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode="nearest")
if not return_code_pred:
return expanded_code_emb
else:
mel_pred = self.mel_head(expanded_code_emb)
            # Multiply mel_pred by the logical-not of unconditioned_batches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the code embedding, as it unbalances contributions to that network from the MSE loss.
mel_pred = mel_pred * unconditioned_batches.logical_not()
return expanded_code_emb, mel_pred
def forward(
self,
x,
timesteps,
aligned_conditioning=None,
conditioning_latent=None,
precomputed_aligned_embeddings=None,
conditioning_free=False,
return_code_pred=False,
):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
:param conditioning_latent: a pre-computed conditioning latent; see get_conditioning().
:param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
:return: an [N x C x ...] Tensor of outputs.
"""
assert precomputed_aligned_embeddings is not None or (
aligned_conditioning is not None and conditioning_latent is not None
)
assert not (
return_code_pred and precomputed_aligned_embeddings is not None
) # These two are mutually exclusive.
unused_params = []
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
unused_params.extend(list(self.latent_conditioner.parameters()))
else:
if precomputed_aligned_embeddings is not None:
code_emb = precomputed_aligned_embeddings
else:
code_emb, mel_pred = self.timestep_independent(
aligned_conditioning, conditioning_latent, x.shape[-1], True
)
if is_latent(aligned_conditioning):
unused_params.extend(
list(self.code_converter.parameters()) + list(self.code_embedding.parameters())
)
else:
unused_params.extend(list(self.latent_conditioner.parameters()))
unused_params.append(self.unconditioned_embedding)
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
x = self.inp_block(x)
x = torch.cat([x, code_emb], dim=1)
x = self.integrating_conv(x)
for i, lyr in enumerate(self.layers):
# Do layer drop where applicable. Do not drop first and last layers.
if (
self.training
and self.layer_drop > 0
and i != 0
and i != (len(self.layers) - 1)
and random.random() < self.layer_drop
):
unused_params.extend(list(lyr.parameters()))
else:
# First and last blocks will have autocast disabled for improved precision.
with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
x = lyr(x, time_emb)
x = x.float()
out = self.out(x)
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
extraneous_addition = 0
for p in unused_params:
extraneous_addition = extraneous_addition + p.mean()
out = out + extraneous_addition * 0
if return_code_pred:
return out, mel_pred
return out
if __name__ == "__main__":
clip = torch.randn(2, 100, 400)
aligned_latent = torch.randn(2, 388, 512)
aligned_sequence = torch.randint(0, 8192, (2, 100))
    ts = torch.LongTensor([600, 600])
    model = DiffusionTts(512, layer_drop=0.3, unconditioned_percentage=0.5)
    # conditioning_latent must come from get_conditioning(), not a raw reference mel
    cond = model.get_conditioning(torch.randn(2, 100, 400))
# Test with latent aligned conditioning
# o = model(clip, ts, aligned_latent, cond)
# Test with sequence aligned conditioning
o = model(clip, ts, aligned_sequence, cond)

File diff suppressed because it is too large

View File

@ -0,0 +1,55 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
if bias is not None:
rest_dim = [1] * (input.ndim - bias.ndim - 1)
return (
F.leaky_relu(
input + bias.view(1, bias.shape[0], *rest_dim),
negative_slope=negative_slope,
)
* scale
)
else:
return F.leaky_relu(input, negative_slope=0.2) * scale
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
return out
class RandomLatentConverter(nn.Module):
def __init__(self, channels):
super().__init__()
self.layers = nn.Sequential(
*[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], nn.Linear(channels, channels)
)
self.channels = channels
def forward(self, ref):
r = torch.randn(ref.shape[0], self.channels, device=ref.device)
y = self.layers(r)
return y
if __name__ == "__main__":
model = RandomLatentConverter(512)
model(torch.randn(5, 512))

View File

@ -0,0 +1,34 @@
import os
import torch
from tokenizers import Tokenizer
from TTS.tts.utils.text.cleaners import english_cleaners
DEFAULT_VOCAB_FILE = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"
)
class VoiceBpeTokenizer:
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
if vocab_file is not None:
self.tokenizer = Tokenizer.from_file(vocab_file)
def preprocess_text(self, txt):
txt = english_cleaners(txt)
return txt
def encode(self, txt):
txt = self.preprocess_text(txt)
txt = txt.replace(" ", "[SPACE]")
return self.tokenizer.encode(txt).ids
def decode(self, seq):
if isinstance(seq, torch.Tensor):
seq = seq.cpu().numpy()
txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "")
txt = txt.replace("[SPACE]", " ")
txt = txt.replace("[STOP]", "")
txt = txt.replace("[UNK]", "")
return txt
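
A usage sketch of the tokenizer, assuming the bundled tokenizer.json asset ships with the package; the exact decoded string depends on english_cleaners:

from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer

tok = VoiceBpeTokenizer()          # loads the default tortoise/tokenizer.json
ids = tok.encode("Hello, world!")  # cleaned text, spaces mapped to [SPACE], then BPE ids
print(ids)
print(tok.decode(ids))             # roughly round-trips to the cleaned text, e.g. "hello, world!"
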

View File

@ -0,0 +1,229 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, depth=1):
if isinstance(val, list):
val = tuple(val)
return val if isinstance(val, tuple) else (val,) * depth
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def stable_softmax(t, dim=-1, alpha=32**2):
t = t / alpha
t = t - torch.amax(t, dim=dim, keepdim=True).detach()
return (t * alpha).softmax(dim=dim)
def route_args(router, args, depth):
routed_args = [(dict(), dict()) for _ in range(depth)]
matched_keys = [key for key in args.keys() if key in router]
for key in matched_keys:
val = args[key]
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
return routed_args
# classes
class SequentialSequence(nn.Module):
def __init__(self, layers, args_route={}, layer_dropout=0.0):
super().__init__()
assert all(
len(route) == len(layers) for route in args_route.values()
), "each argument route map must have the same depth as the number of sequential layers"
self.layers = layers
self.args_route = args_route
self.layer_dropout = layer_dropout
def forward(self, x, **kwargs):
args = route_args(self.args_route, kwargs, len(self.layers))
layers_and_args = list(zip(self.layers, args))
for (f, g), (f_args, g_args) in layers_and_args:
x = x + f(x, **f_args)
x = x + g(x, **g_args)
return x
class DivideMax(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
maxes = x.amax(dim=self.dim, keepdim=True).detach()
return x / maxes
# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
def __init__(self, dim, depth, fn):
super().__init__()
if depth <= 18:
init_eps = 0.1
elif depth > 18 and depth <= 24:
init_eps = 1e-5
else:
init_eps = 1e-6
scale = torch.zeros(1, 1, dim).fill_(init_eps)
self.scale = nn.Parameter(scale)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
# layer norm
class PreNorm(nn.Module):
def __init__(self, dim, fn, sandwich=False):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
x = self.fn(x, **kwargs)
return self.norm_out(x)
# feed forward
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim=-1)
return x * F.gelu(gates)
class FeedForward(nn.Module):
def __init__(self, dim, dropout=0.0, mult=4.0):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim),
)
def forward(self, x):
return self.net(x)
# Attention
class Attention(nn.Module):
def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.seq_len = seq_len
self.scale = dim_head**-0.5
self.causal = causal
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
def forward(self, x, mask=None):
b, n, _, h, device = *x.shape, self.heads, x.device
softmax = torch.softmax
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
q = q * self.scale
dots = torch.einsum("b h i d, b h j d -> b h i j", q, k)
mask_value = max_neg_value(dots)
if exists(mask):
mask = rearrange(mask, "b j -> b () () j")
dots.masked_fill_(~mask, mask_value)
del mask
if self.causal:
i, j = dots.shape[-2:]
mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool()
dots.masked_fill_(mask, mask_value)
attn = softmax(dots, dim=-1)
out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.to_out(out)
return out
# main transformer class
class Transformer(nn.Module):
def __init__(
self,
*,
dim,
depth,
seq_len,
causal=True,
heads=8,
dim_head=64,
ff_mult=4,
attn_dropout=0.0,
ff_dropout=0.0,
sparse_attn=False,
sandwich_norm=False,
):
super().__init__()
layers = nn.ModuleList([])
sparse_layer = cast_tuple(sparse_attn, depth)
for ind, sparse_attn in zip(range(depth), sparse_layer):
attn = Attention(
dim,
causal=causal,
seq_len=seq_len,
heads=heads,
dim_head=dim_head,
dropout=attn_dropout,
)
ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)
layers.append(
nn.ModuleList(
[
LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)),
LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)),
]
)
)
execute_type = SequentialSequence
route_attn = ((True, False),) * depth
attn_route_map = {"mask": route_attn}
self.layers = execute_type(layers, args_route=attn_route_map)
def forward(self, x, **kwargs):
return self.layers(x, **kwargs)
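
A shape-level sketch of the non-xformers encoder path used by the CLVP module above (module path assumed to be TTS.tts.layers.tortoise.transformer):

import torch

from TTS.tts.layers.tortoise.transformer import Transformer

enc = Transformer(dim=512, depth=2, seq_len=120, causal=False, heads=8)
x = torch.randn(2, 120, 512)      # (batch, seq, dim) token embeddings
mask = torch.ones(2, 120).bool()  # routed only to the attention layers via route_args
y = enc(x, mask=mask)
print(y.shape)                    # torch.Size([2, 120, 512])
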

View File

@ -0,0 +1,46 @@
import os
from urllib import request
from tqdm import tqdm
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
MODELS_DIR = "/data/speech_synth/models/"
MODELS = {
"autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
"classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth",
"clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth",
"diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth",
"vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth",
"rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth",
"rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth",
}
def download_models(specific_models=None):
"""
Call to download all the models that Tortoise uses.
"""
os.makedirs(MODELS_DIR, exist_ok=True)
for model_name, url in MODELS.items():
if specific_models is not None and model_name not in specific_models:
continue
model_path = os.path.join(MODELS_DIR, model_name)
if os.path.exists(model_path):
continue
print(f"Downloading {model_name} from {url}...")
with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n))
print("Done.")
def get_model_path(model_name, models_dir=MODELS_DIR):
"""
Get path to given model, download it if it doesn't exist.
"""
if model_name not in MODELS:
raise ValueError(f"Model {model_name} not found in available models.")
model_path = os.path.join(models_dir, model_name)
if not os.path.exists(model_path) and models_dir == MODELS_DIR:
download_models([model_name])
return model_path
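
Typical use is to ask for a file by name and let the helper download it on first use. A sketch, assuming it runs alongside this module; note that the hard-coded MODELS_DIR assignment above overrides the TORTOISE_MODELS_DIR environment variable, so adjust it for your setup:

# resolve (and download on first use) the random latent generator checkpoints
rlg_auto_path = get_model_path("rlg_auto.pth")
rlg_diffuser_path = get_model_path("rlg_diffuser.pth")
print(rlg_auto_path, rlg_diffuser_path)
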

View File

@ -0,0 +1,401 @@
import json
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
MAX_WAV_VALUE = 32768.0
class KernelPredictor(torch.nn.Module):
"""Kernel predictor for the location-variable convolutions"""
def __init__(
self,
cond_channels,
conv_in_channels,
conv_out_channels,
conv_layers,
conv_kernel_size=3,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
kpnet_nonlinear_activation="LeakyReLU",
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
):
"""
Args:
            cond_channels (int): number of channels for the conditioning sequence,
            conv_in_channels (int): number of channels for the input sequence,
            conv_out_channels (int): number of channels for the output sequence,
conv_layers (int): number of layers
"""
super().__init__()
self.conv_in_channels = conv_in_channels
self.conv_out_channels = conv_out_channels
self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.residual_convs = nn.ModuleList()
padding = (kpnet_conv_size - 1) // 2
for _ in range(3):
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
self.kernel_conv = nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_kernel_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
)
self.bias_conv = nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_bias_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
)
def forward(self, c):
"""
Args:
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
"""
batch, _, cond_length = c.shape
c = self.input_conv(c)
for residual_conv in self.residual_convs:
residual_conv.to(c.device)
c = c + residual_conv(c)
k = self.kernel_conv(c)
b = self.bias_conv(c)
kernels = k.contiguous().view(
batch,
self.conv_layers,
self.conv_in_channels,
self.conv_out_channels,
self.conv_kernel_size,
cond_length,
)
bias = b.contiguous().view(
batch,
self.conv_layers,
self.conv_out_channels,
cond_length,
)
return kernels, bias
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input_conv[0])
nn.utils.remove_weight_norm(self.kernel_conv)
nn.utils.remove_weight_norm(self.bias_conv)
for block in self.residual_convs:
nn.utils.remove_weight_norm(block[1])
nn.utils.remove_weight_norm(block[3])
class LVCBlock(torch.nn.Module):
"""the location-variable convolutions"""
def __init__(
self,
in_channels,
cond_channels,
stride,
dilations=[1, 3, 9, 27],
lReLU_slope=0.2,
conv_kernel_size=3,
cond_hop_length=256,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
):
super().__init__()
self.cond_hop_length = cond_hop_length
self.conv_layers = len(dilations)
self.conv_kernel_size = conv_kernel_size
self.kernel_predictor = KernelPredictor(
cond_channels=cond_channels,
conv_in_channels=in_channels,
conv_out_channels=2 * in_channels,
conv_layers=len(dilations),
conv_kernel_size=conv_kernel_size,
kpnet_hidden_channels=kpnet_hidden_channels,
kpnet_conv_size=kpnet_conv_size,
kpnet_dropout=kpnet_dropout,
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
)
self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(
nn.ConvTranspose1d(
in_channels,
in_channels,
2 * stride,
stride=stride,
padding=stride // 2 + stride % 2,
output_padding=stride % 2,
)
),
)
self.conv_blocks = nn.ModuleList()
for dilation in dilations:
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(
nn.Conv1d(
in_channels,
in_channels,
conv_kernel_size,
padding=dilation * (conv_kernel_size - 1) // 2,
dilation=dilation,
)
),
nn.LeakyReLU(lReLU_slope),
)
)
def forward(self, x, c):
"""forward propagation of the location-variable convolutions.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length)
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
Returns:
Tensor: the output sequence (batch, in_channels, in_length)
"""
_, in_channels, _ = x.shape # (B, c_g, L')
x = self.convt_pre(x) # (B, c_g, stride * L')
kernels, bias = self.kernel_predictor(c)
for i, conv in enumerate(self.conv_blocks):
output = conv(x) # (B, c_g, stride * L')
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
output = self.location_variable_convolution(
output, k, b, hop_size=self.cond_hop_length
) # (B, 2 * c_g, stride * L'): LVC
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
output[:, in_channels:, :]
) # (B, c_g, stride * L'): GAU
return x
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length).
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
dilation (int): the dilation of convolution.
hop_size (int): the hop_size of the conditioning sequence.
Returns:
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
"""
batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation:
x = F.pad(x, (0, dilation), "constant", 0)
x = x.unfold(
3, dilation, dilation
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
o = o.to(memory_format=torch.channels_last_3d)
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
o = o + bias
o = o.contiguous().view(batch, out_channels, -1)
return o
def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm()
nn.utils.remove_weight_norm(self.convt_pre[1])
for block in self.conv_blocks:
nn.utils.remove_weight_norm(block[1])
class UnivNetGenerator(nn.Module):
"""
UnivNet Generator
Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py.
"""
def __init__(
self,
noise_dim=64,
channel_size=32,
dilations=[1, 3, 9, 27],
strides=[8, 8, 4],
lReLU_slope=0.2,
kpnet_conv_size=3,
# Below are MEL configurations options that this generator requires.
hop_length=256,
n_mel_channels=100,
):
super(UnivNetGenerator, self).__init__()
self.mel_channel = n_mel_channels
self.noise_dim = noise_dim
self.hop_length = hop_length
channel_size = channel_size
kpnet_conv_size = kpnet_conv_size
self.res_stack = nn.ModuleList()
hop_length = 1
for stride in strides:
hop_length = stride * hop_length
self.res_stack.append(
LVCBlock(
channel_size,
n_mel_channels,
stride=stride,
dilations=dilations,
lReLU_slope=lReLU_slope,
cond_hop_length=hop_length,
kpnet_conv_size=kpnet_conv_size,
)
)
self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
nn.Tanh(),
)
def forward(self, c, z):
"""
Args:
c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
z (Tensor): the noise sequence (batch, noise_dim, in_length)
"""
z = self.conv_pre(z) # (B, c_g, L)
for res_block in self.res_stack:
res_block.to(z.device)
z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
z = self.conv_post(z) # (B, 1, L * 256)
return z
def eval(self, inference=False):
super(UnivNetGenerator, self).eval()
# don't remove weight norm while validation in training loop
if inference:
self.remove_weight_norm()
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv_pre)
for layer in self.conv_post:
if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer)
for res_block in self.res_stack:
res_block.remove_weight_norm()
def inference(self, c, z=None):
# pad input mel with zeros to cut artifact
# see https://github.com/seungwonpark/melgan/issues/8
zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
mel = torch.cat((c, zero), dim=2)
if z is None:
z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
audio = self.forward(mel, z)
audio = audio[:, :, : -(self.hop_length * 10)]
audio = audio.clamp(min=-1, max=1)
return audio
@dataclass
class VocType:
constructor: Callable[[], nn.Module]
model_path: str
subkey: Optional[str] = None
def optionally_index(self, model_dict):
if self.subkey is not None:
return model_dict[self.subkey]
return model_dict
class VocConf(Enum):
Univnet = VocType(UnivNetGenerator, "vocoder.pth", "model_g")
if __name__ == "__main__":
model = UnivNetGenerator()
c = torch.randn(3, 100, 10)
z = torch.randn(3, 64, 10)
print(c.shape)
y = model(c, z)
print(y.shape)
assert y.shape == torch.Size([3, 1, 2560])
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)

View File

@ -0,0 +1,150 @@
import torch
import torchaudio
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC
def max_alignment(s1, s2, skip_character="~", record=None):
"""
A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is
used to replace that character.
Finally got to use my DP skills!
"""
if record is None:
record = {}
assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"
if len(s1) == 0:
return ""
if len(s2) == 0:
return skip_character * len(s1)
if s1 == s2:
return s1
if s1[0] == s2[0]:
return s1[0] + max_alignment(s1[1:], s2[1:], skip_character, record)
take_s1_key = (len(s1), len(s2) - 1)
if take_s1_key in record:
take_s1, take_s1_score = record[take_s1_key]
else:
take_s1 = max_alignment(s1, s2[1:], skip_character, record)
take_s1_score = len(take_s1.replace(skip_character, ""))
record[take_s1_key] = (take_s1, take_s1_score)
take_s2_key = (len(s1) - 1, len(s2))
if take_s2_key in record:
take_s2, take_s2_score = record[take_s2_key]
else:
take_s2 = max_alignment(s1[1:], s2, skip_character, record)
take_s2_score = len(take_s2.replace(skip_character, ""))
record[take_s2_key] = (take_s2, take_s2_score)
return take_s1 if take_s1_score > take_s2_score else skip_character + take_s2
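
A tiny example of the alignment behaviour described in the docstring (import path assumed, matching how tortoise.py imports from this file):

from TTS.tts.layers.tortoise.wav2vec_alignment import max_alignment

# characters of s1 that cannot be matched in s2 are replaced by the skip character
print(max_alignment("abcde", "ace"))  # -> "a~c~e"
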
class Wav2VecAlignment:
"""
Uses wav2vec2 to perform audio<->text alignment.
"""
def __init__(self, device="cuda"):
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("jbetker/tacotron-symbols")
self.device = device
def align(self, audio, expected_text, audio_sample_rate=24000):
orig_len = audio.shape[-1]
with torch.no_grad():
self.model = self.model.to(self.device)
audio = audio.to(self.device)
audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
logits = self.model(clip_norm).logits
self.model = self.model.cpu()
logits = logits[0]
pred_string = self.tokenizer.decode(logits.argmax(-1).tolist())
fixed_expectation = max_alignment(expected_text.lower(), pred_string)
w2v_compression = orig_len // logits.shape[0]
expected_tokens = self.tokenizer.encode(fixed_expectation)
expected_chars = list(fixed_expectation)
if len(expected_tokens) == 1:
return [0] # The alignment is simple; there is only one token.
expected_tokens.pop(0) # The first token is a given.
expected_chars.pop(0)
alignments = [0]
def pop_till_you_win():
if len(expected_tokens) == 0:
return None
popped = expected_tokens.pop(0)
popped_char = expected_chars.pop(0)
while popped_char == "~":
alignments.append(-1)
if len(expected_tokens) == 0:
return None
popped = expected_tokens.pop(0)
popped_char = expected_chars.pop(0)
return popped
next_expected_token = pop_till_you_win()
for i, logit in enumerate(logits):
top = logit.argmax()
if next_expected_token == top:
alignments.append(i * w2v_compression)
if len(expected_tokens) > 0:
next_expected_token = pop_till_you_win()
else:
break
pop_till_you_win()
if not (len(expected_tokens) == 0 and len(alignments) == len(expected_text)):
torch.save([audio, expected_text], "alignment_debug.pth")
assert False, (
"Something went wrong with the alignment algorithm. I've dumped a file, 'alignment_debug.pth' to"
"your current working directory. Please report this along with the file so it can get fixed."
)
# Now fix up alignments. Anything with -1 should be interpolated.
alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable.
for i in range(len(alignments)):
if alignments[i] == -1:
for j in range(i + 1, len(alignments)):
if alignments[j] != -1:
next_found_token = j
break
for j in range(i, next_found_token):
gap = alignments[next_found_token] - alignments[i - 1]
alignments[j] = (j - i + 1) * gap // (next_found_token - i + 1) + alignments[i - 1]
return alignments[:-1]
def redact(self, audio, expected_text, audio_sample_rate=24000):
if "[" not in expected_text:
return audio
splitted = expected_text.split("[")
fully_split = [splitted[0]]
for spl in splitted[1:]:
assert "]" in spl, 'Every "[" character must be paired with a "]" with no nesting.'
fully_split.extend(spl.split("]"))
# At this point, fully_split is a list of strings, with every other string being something that should be redacted.
non_redacted_intervals = []
last_point = 0
for i in range(len(fully_split)):
if i % 2 == 0:
end_interval = max(0, last_point + len(fully_split[i]) - 1)
non_redacted_intervals.append((last_point, end_interval))
last_point += len(fully_split[i])
bare_text = "".join(fully_split)
alignments = self.align(audio, bare_text, audio_sample_rate)
output_audio = []
for nri in non_redacted_intervals:
start, stop = nri
output_audio.append(audio[:, alignments[start] : alignments[stop]])
return torch.cat(output_audio, dim=-1)
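
A redaction sketch; the wav path and the bracketed prompt are illustrative, the first call downloads the wav2vec2 checkpoints from Hugging Face, and device="cpu" also works if no GPU is available:

import torchaudio

from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment

audio, sr = torchaudio.load("reference_24khz.wav")  # illustrative path, mono waveform
aligner = Wav2VecAlignment(device="cuda")
# everything inside [brackets] is spoken by the model but trimmed from the returned audio
trimmed = aligner.redact(audio, "[I am really sad,] Please feed me.", audio_sample_rate=sr)
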

File diff suppressed because it is too large Load Diff

View File

@ -10,5 +10,5 @@ def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None)
MyModel = find_module("TTS.tts.models", config.base_model.lower())
else:
MyModel = find_module("TTS.tts.models", config.model.lower())
model = MyModel.init_from_config(config, samples)
model = MyModel.init_from_config(config=config, samples=samples)
return model

900  TTS/tts/models/tortoise.py  Normal file
View File

@ -0,0 +1,900 @@
import os
import random
from contextlib import contextmanager
from dataclasses import dataclass
from time import time
import torch
import torch.nn.functional as F
import torchaudio
from coqpit import Coqpit
from tqdm import tqdm
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, load_voice, wav_to_univnet_mel
from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
from TTS.tts.layers.tortoise.clvp import CLVP
from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.tortoise.vocoder import VocConf, VocType
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.models.base_tts import BaseTTS
def pad_or_truncate(t, length):
"""
Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
"""
tp = t[..., :length]
if t.shape[-1] == length:
tp = t
elif t.shape[-1] < length:
tp = F.pad(t, (0, length - t.shape[-1]))
return tp
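
For example, a quick shape check (assuming the module is importable as TTS.tts.models.tortoise):

import torch

from TTS.tts.models.tortoise import pad_or_truncate

t = torch.randn(1, 100, 350)
print(pad_or_truncate(t, 500).shape)  # torch.Size([1, 100, 500]), zero-padded on the right
print(pad_or_truncate(t, 200).shape)  # torch.Size([1, 100, 200]), clipped
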
def deterministic_state(seed=None):
"""
    Sets the random seeds that Tortoise uses to the given seed, or to the current time() if none is given, and
    returns that seed so results can be reproduced.
"""
seed = int(time()) if seed is None else seed
torch.manual_seed(seed)
random.seed(seed)
# Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
# torch.use_deterministic_algorithms(True)
return seed
def load_discrete_vocoder_diffuser(
trained_diffusion_steps=4000,
desired_diffusion_steps=200,
cond_free=True,
cond_free_k=1,
sampler="ddim",
):
"""
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
"""
return SpacedDiffusion(
use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
model_mean_type="epsilon",
model_var_type="learned_range",
loss_type="mse",
betas=get_named_beta_schedule("linear", trained_diffusion_steps),
conditioning_free=cond_free,
conditioning_free_k=cond_free_k,
sampler=sampler,
)
def format_conditioning(clip, cond_length=132300, device="cuda"):
"""
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
"""
gap = clip.shape[-1] - cond_length
if gap < 0:
clip = F.pad(clip, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
clip = clip[:, rand_start : rand_start + cond_length]
mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
return mel_clip.unsqueeze(0).to(device)
def fix_autoregressive_output(codes, stop_token, complain=True):
"""
This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
trained on and what the autoregressive code generator creates (which has no padding or end).
This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
    a different DVAE. This can be inferred by feeding an audio clip padded with lots of zeros on the end through the DVAE
and copying out the last few codes.
Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
"""
# Strip off the autoregressive stop token and add padding.
stop_token_indices = (codes == stop_token).nonzero()
if len(stop_token_indices) == 0:
if complain:
print(
"No stop tokens found in one of the generated voice clips. This typically means the spoken audio is "
"too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, "
"try breaking up your input text."
)
return codes
codes[stop_token_indices] = 83
stm = stop_token_indices.min().item()
codes[stm:] = 83
if stm - 3 < codes.shape[0]:
codes[-3] = 45
codes[-2] = 45
codes[-1] = 248
return codes
def do_spectrogram_diffusion(
diffusion_model,
diffuser,
latents,
conditioning_latents,
temperature=1,
verbose=True,
):
"""
Uses the specified diffusion model to convert discrete codes into a spectrogram.
"""
with torch.no_grad():
output_seq_len = (
latents.shape[1] * 4 * 24000 // 22050
) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_shape = (latents.shape[0], 100, output_seq_len)
precomputed_embeddings = diffusion_model.timestep_independent(
latents, conditioning_latents, output_seq_len, False
)
noise = torch.randn(output_shape, device=latents.device) * temperature
mel = diffuser.sample_loop(
diffusion_model,
output_shape,
noise=noise,
model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
progress=verbose,
)
return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
def classify_audio_clip(clip, model_dir):
"""
    Returns whether or not Tortoise's classifier thinks the given clip came from Tortoise.
:param clip: torch tensor containing audio waveform data (get it from load_audio)
    :return: The classifier's probability that the clip came from Tortoise, with low values meaning it was classified as real.
"""
classifier = AudioMiniEncoderWithClassifierHead(
2,
spec_dim=1,
embedding_dim=512,
depth=5,
downsample_factor=4,
resnet_blocks=2,
attn_blocks=4,
num_attn_heads=4,
base_channels=32,
dropout=0,
kernel_size=5,
distribute_zero_label=False,
)
classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu")))
clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1)
return results[0][0]
def pick_best_batch_size_for_gpu():
"""
Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give
you a good shot.
"""
if torch.cuda.is_available():
_, available = torch.cuda.mem_get_info()
availableGb = available / (1024**3)
batch_size = 1
if availableGb > 14:
batch_size = 16
elif availableGb > 10:
batch_size = 8
elif availableGb > 7:
batch_size = 4
return batch_size
@dataclass
class TortoiseAudioConfig(Coqpit):
sample_rate: int = 22050
diffusion_sample_rate: int = 24000
output_sample_rate: int = 24000
@dataclass
class TortoiseArgs(Coqpit):
"""A dataclass to represent Tortoise model arguments that define the model structure.
Args:
autoregressive_batch_size (int): The size of the auto-regressive batch.
enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
high_vram (bool, optional): Whether to use high VRAM. Defaults to False.
kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
ar_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
        clvp_checkpoint (str, optional): The checkpoint for the CLVP model. Defaults to None.
diff_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
For UnifiedVoice model:
ar_max_mel_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
ar_max_conditioning_inputs (int, optional): The maximum conditioning inputs for the autoregressive model. Defaults to 2.
ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
ar_model_dim (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
ar_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
ar_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False.
ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
For DiffTTS model:
diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024.
diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10.
diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100.
diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200.
diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024.
diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193.
diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0.
diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. Defaults to False.
diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16.
diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0.
diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0.
        For CLVP model:
clvp_dim_text (int): The dimension of the text input for the CLVP module. Defaults to 768.
clvp_dim_speech (int): The dimension of the speech input for the CLVP module. Defaults to 768.
clvp_dim_latent (int): The dimension of the latent representation for the CLVP module. Defaults to 768.
clvp_num_text_tokens (int): The number of text tokens used by the CLVP module. Defaults to 256.
clvp_text_enc_depth (int): The depth of the text encoder in the CLVP module. Defaults to 20.
clvp_text_seq_len (int): The maximum sequence length of the text input for the CLVP module. Defaults to 350.
clvp_text_heads (int): The number of attention heads used by the text encoder in the CLVP module. Defaults to 12.
clvp_num_speech_tokens (int): The number of speech tokens used by the CLVP module. Defaults to 8192.
clvp_speech_enc_depth (int): The depth of the speech encoder in the CLVP module. Defaults to 20.
clvp_speech_heads (int): The number of attention heads used by the speech encoder in the CLVP module. Defaults to 12.
clvp_speech_seq_len (int): The maximum sequence length of the speech input for the CLVP module. Defaults to 430.
clvp_use_xformers (bool): A flag indicating whether the model uses transformers in the CLVP module. Defaults to True.
duration_const (int): A constant value used in the model. Defaults to 102400.
"""
autoregressive_batch_size: int = 1
enable_redaction: bool = True
high_vram: bool = False
kv_cache: bool = True
ar_checkpoint: str = None
clvp_checkpoint: str = None
diff_checkpoint: str = None
num_chars: int = 255
vocoder: VocType = VocConf.Univnet
# UnifiedVoice params
ar_max_mel_tokens: int = 604
ar_max_text_tokens: int = 402
ar_max_conditioning_inputs: int = 2
ar_layers: int = 30
ar_model_dim: int = 1024
ar_heads: int = 16
ar_number_text_tokens: int = 255
ar_start_text_token: int = 255
ar_checkpointing: bool = False
ar_train_solo_embeddings: bool = False
# DiffTTS params
diff_model_channels: int = 1024
diff_num_layers: int = 10
diff_in_channels: int = 100
diff_out_channels: int = 200
diff_in_latent_channels: int = 1024
diff_in_tokens: int = 8193
diff_dropout: int = 0
diff_use_fp16: bool = False
diff_num_heads: int = 16
diff_layer_drop: int = 0
diff_unconditioned_percentage: int = 0
# clvp params
clvp_dim_text: int = 768
clvp_dim_speech: int = 768
clvp_dim_latent: int = 768
clvp_num_text_tokens: int = 256
clvp_text_enc_depth: int = 20
clvp_text_seq_len: int = 350
clvp_text_heads: int = 12
clvp_num_speech_tokens: int = 8192
clvp_speech_enc_depth: int = 20
clvp_speech_heads: int = 12
clvp_speech_seq_len: int = 430
clvp_use_xformers: bool = True
# constants
duration_const: int = 102400
class Tortoise(BaseTTS):
"""Tortoise model class.
Currently only supports inference.
Examples:
>>> from TTS.tts.configs.tortoise_config import TortoiseConfig
>>> from TTS.tts.models.tortoise import Tortoise
>>> config = TortoiseConfig()
        >>> model = Tortoise.init_from_config(config)
>>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True)
"""
def __init__(self, config: Coqpit):
super().__init__(config, ap=None, tokenizer=None)
self.config = config
self.ar_checkpoint = self.args.ar_checkpoint
self.diff_checkpoint = self.args.diff_checkpoint # TODO: check if this is even needed
self.models_dir = config.model_dir
self.autoregressive_batch_size = (
pick_best_batch_size_for_gpu()
if self.args.autoregressive_batch_size is None
else self.args.autoregressive_batch_size
)
self.enable_redaction = self.args.enable_redaction
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if self.enable_redaction:
self.aligner = Wav2VecAlignment()
self.tokenizer = VoiceBpeTokenizer()
self.autoregressive = UnifiedVoice(
max_mel_tokens=self.args.ar_max_mel_tokens,
max_text_tokens=self.args.ar_max_text_tokens,
max_conditioning_inputs=self.args.ar_max_conditioning_inputs,
layers=self.args.ar_layers,
model_dim=self.args.ar_model_dim,
heads=self.args.ar_heads,
number_text_tokens=self.args.ar_number_text_tokens,
start_text_token=self.args.ar_start_text_token,
checkpointing=self.args.ar_checkpointing,
train_solo_embeddings=self.args.ar_train_solo_embeddings,
).cpu()
self.diffusion = DiffusionTts(
model_channels=self.args.diff_model_channels,
num_layers=self.args.diff_num_layers,
in_channels=self.args.diff_in_channels,
out_channels=self.args.diff_out_channels,
in_latent_channels=self.args.diff_in_latent_channels,
in_tokens=self.args.diff_in_tokens,
dropout=self.args.diff_dropout,
use_fp16=self.args.diff_use_fp16,
num_heads=self.args.diff_num_heads,
layer_drop=self.args.diff_layer_drop,
unconditioned_percentage=self.args.diff_unconditioned_percentage,
).cpu()
self.clvp = CLVP(
dim_text=self.args.clvp_dim_text,
dim_speech=self.args.clvp_dim_speech,
dim_latent=self.args.clvp_dim_latent,
num_text_tokens=self.args.clvp_num_text_tokens,
text_enc_depth=self.args.clvp_text_enc_depth,
text_seq_len=self.args.clvp_text_seq_len,
text_heads=self.args.clvp_text_heads,
num_speech_tokens=self.args.clvp_num_speech_tokens,
speech_enc_depth=self.args.clvp_speech_enc_depth,
speech_heads=self.args.clvp_speech_heads,
speech_seq_len=self.args.clvp_speech_seq_len,
use_xformers=self.args.clvp_use_xformers,
).cpu()
self.vocoder = self.args.vocoder.value.constructor().cpu()
# Random latent generators (RLGs) are loaded lazily.
self.rlg_auto = None
self.rlg_diffusion = None
if self.args.high_vram:
self.autoregressive = self.autoregressive.to(self.device)
self.diffusion = self.diffusion.to(self.device)
self.clvp = self.clvp.to(self.device)
self.vocoder = self.vocoder.to(self.device)
self.high_vram = self.args.high_vram
@contextmanager
def temporary_cuda(self, model):
if self.high_vram:
yield model
else:
m = model.to(self.device)
yield m
m = model.cpu()
def get_conditioning_latents(
self,
voice_samples,
return_mels=False,
latent_averaging_mode=0,
original_tortoise=False,
):
"""
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
properties.
:param voice_samples: List of arbitrary reference clips, which should be *pairs* of torch tensors containing arbitrary kHz waveform data.
:param latent_averaging_mode: 0/1/2 for following modes:
0 - latents will be generated as in original tortoise, using ~4.27s from each voice sample, averaging latent across all samples
1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks
2 - latents will be generated using (almost) entire voice samples, averaged per voice sample
"""
assert latent_averaging_mode in [
0,
1,
2,
], "latent_averaging mode has to be one of (0, 1, 2)"
with torch.no_grad():
voice_samples = [[v.to(self.device) for v in ls] for ls in voice_samples]
auto_conds = []
for ls in voice_samples:
auto_conds.append(format_conditioning(ls[0], device=self.device))
auto_conds = torch.stack(auto_conds, dim=1)
with self.temporary_cuda(self.autoregressive) as ar:
auto_latent = ar.get_conditioning(auto_conds)
diffusion_conds = []
DURS_CONST = self.args.duration_const
for ls in voice_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1]
if latent_averaging_mode == 0:
sample = pad_or_truncate(sample, DURS_CONST)
cond_mel = wav_to_univnet_mel(
sample.to(self.device),
do_normalization=False,
device=self.device,
)
diffusion_conds.append(cond_mel)
else:
from math import ceil
if latent_averaging_mode == 2:
temp_diffusion_conds = []
for chunk in range(ceil(sample.shape[1] / DURS_CONST)):
current_sample = sample[:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST]
current_sample = pad_or_truncate(current_sample, DURS_CONST)
cond_mel = wav_to_univnet_mel(
current_sample.to(self.device),
do_normalization=False,
device=self.device,
)
if latent_averaging_mode == 1:
diffusion_conds.append(cond_mel)
elif latent_averaging_mode == 2:
temp_diffusion_conds.append(cond_mel)
if latent_averaging_mode == 2:
diffusion_conds.append(torch.stack(temp_diffusion_conds).mean(0))
diffusion_conds = torch.stack(diffusion_conds, dim=1)
with self.temporary_cuda(self.diffusion) as diffusion:
diffusion_latent = diffusion.get_conditioning(diffusion_conds)
if return_mels:
return auto_latent, diffusion_latent, auto_conds, diffusion_conds
return auto_latent, diffusion_latent
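
A sketch of reusing precomputed latents across several sentences, given a model and config loaded as in the class docstring above; the voice name and directory are illustrative, and load_voice is assumed to return the (samples, latents) pair used by synthesize() below:

from TTS.tts.layers.tortoise.audio_utils import load_voice

voice_samples, _ = load_voice("my_speaker", ["my_voices/"])  # illustrative voice dir
auto_latent, diff_latent = model.get_conditioning_latents(voice_samples, latent_averaging_mode=1)
first = model.inference("First sentence.", conditioning_latents=(auto_latent, diff_latent))
second = model.inference("Second sentence.", conditioning_latents=(auto_latent, diff_latent))
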
def get_random_conditioning_latents(self):
# Lazy-load the RLG models.
if self.rlg_auto is None:
self.rlg_auto = RandomLatentConverter(1024).eval()
self.rlg_auto.load_state_dict(
torch.load(
os.path.join(self.models_dir, "rlg_auto.pth"),
map_location=torch.device("cpu"),
)
)
self.rlg_diffusion = RandomLatentConverter(2048).eval()
self.rlg_diffusion.load_state_dict(
torch.load(
os.path.join(self.models_dir, "rlg_diffuser.pth"),
map_location=torch.device("cpu"),
)
)
with torch.no_grad():
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
def synthesize(self, text, config, speaker_id="random", extra_voice_dirs=None, **kwargs):
"""Synthesize speech with the given input text.
Args:
text (str): Input text.
config (TortoiseConfig): Config with inference parameters.
speaker_id (str): One of the available speaker names. If `random`, it generates a random speaker.
extra_voice_dirs (List[str]): List of paths that host reference audio files for speakers. Defaults to None.
**kwargs: Inference settings. See `inference()`.
Returns:
A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference,
`text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents`
as latents used at inference.
"""
if extra_voice_dirs is not None:
extra_voice_dirs = [extra_voice_dirs]
voice_samples, conditioning_latents = load_voice(speaker_id, extra_voice_dirs)
else:
voice_samples, conditioning_latents = load_voice(speaker_id)
outputs = self.inference_with_config(
text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents, **kwargs
)
return_dict = {
"wav": outputs["wav"],
"deterministic_seed": outputs["deterministic_seed"],
"text_inputs": outputs["text"],
"voice_samples": outputs["voice_samples"],
"conditioning_latents": outputs["conditioning_latents"],
}
return return_dict
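
End to end, with the released Tortoise model files downloaded into a local directory; the paths and preset choice below are illustrative:

from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.models.tortoise import Tortoise

config = TortoiseConfig()
model = Tortoise.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True)
out = model.synthesize(
    "Hello from Tortoise.",
    config,
    speaker_id="random",  # or a named voice, optionally with extra_voice_dirs
    preset="ultra_fast",  # forwarded to inference_with_config() below
)
wav = out["wav"]
seed = out["deterministic_seed"]
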
def inference_with_config(self, text, config, **kwargs):
"""
        Run `inference()` with settings taken from the given config, a named preset, and any explicit kwargs.
#TODO describe in detail
"""
# Use generally found best tuning knobs for generation.
settings = {
"temperature": config.temperature,
"length_penalty": config.length_penalty,
"repetition_penalty": config.repetition_penalty,
"top_p": config.top_p,
"cond_free_k": config.cond_free_k,
"diffusion_temperature": config.diffusion_temperature,
"sampler": config.sampler,
}
# Presets are defined here.
presets = {
"single_sample": {
"num_autoregressive_samples": 8,
"diffusion_iterations": 10,
"sampler": "ddim",
},
"ultra_fast": {
"num_autoregressive_samples": 16,
"diffusion_iterations": 10,
"sampler": "ddim",
},
"ultra_fast_old": {
"num_autoregressive_samples": 16,
"diffusion_iterations": 30,
"cond_free": False,
},
"very_fast": {
"num_autoregressive_samples": 32,
"diffusion_iterations": 30,
"sampler": "dpm++2m",
},
"fast": {
"num_autoregressive_samples": 5,
"diffusion_iterations": 50,
"sampler": "ddim",
},
"fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
"standard": {
"num_autoregressive_samples": 5,
"diffusion_iterations": 200,
},
"high_quality": {
"num_autoregressive_samples": 256,
"diffusion_iterations": 400,
},
}
if "preset" in kwargs:
settings.update(presets[kwargs["preset"]])
kwargs.pop("preset")
settings.update(kwargs) # allow overriding of preset settings with kwargs
return self.inference(text, **settings)
def inference(
self,
text,
voice_samples=None,
conditioning_latents=None,
k=1,
verbose=True,
use_deterministic_seed=None,
return_deterministic_state=False,
latent_averaging_mode=0,
# autoregressive generation parameters follow
num_autoregressive_samples=16,
temperature=0.8,
length_penalty=1,
repetition_penalty=2.0,
top_p=0.8,
max_mel_tokens=500,
# diffusion generation parameters follow
diffusion_iterations=100,
cond_free=True,
cond_free_k=2,
diffusion_temperature=1.0,
sampler="ddim",
half=True,
original_tortoise=False,
**hf_generate_kwargs,
):
"""
This function produces an audio clip of the given text being spoken with the given reference voice.
Args:
text: (str) Text to be spoken.
voice_samples: (List[Tuple[torch.Tensor]]) List of an arbitrary number of reference clips, each a pair
of torch tensors containing waveform data.
conditioning_latents: (Tuple[autoregressive_conditioning_latent, diffusion_conditioning_latent]) A tuple of
(autoregressive_conditioning_latent, diffusion_conditioning_latent), which can be provided in lieu
of voice_samples. This is ignored unless `voice_samples=None`. Conditioning latents can be retrieved
via `get_conditioning_latents()`.
k: (int) The number of returned clips. The most likely (as determined by Tortoise's CLVP model) clips are returned.
latent_averaging_mode: (int) 0/1/2 for following modes:
0 - latents will be generated as in original tortoise, using ~4.27s from each voice sample, averaging latent across all samples
1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks
2 - latents will be generated using (almost) entire voice samples, averaged per voice sample
verbose: (bool) Whether or not to print log messages indicating the progress of creating a clip. Defaults to True.
num_autoregressive_samples: (int) Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great".
temperature: (float) The softmax temperature of the autoregressive model.
length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings cause the model to produce more terse outputs.
repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce
the incidence of long silences or "uhhhhhhs", etc.
top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs.
max_mel_tokens: (int) Restricts the output length. (0,600] integer. Each unit is 1/20 of a second.
diffusion_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively
refine the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, however.
cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for
each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output of the two
is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and dramatically improves realism.
cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
are the "mean" prediction of the diffusion network and will sound bland and smeared.
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive transformer.
Extra keyword args fed to this function get forwarded directly to that API. Documentation
here: https://huggingface.co/docs/transformers/internal/generation_utils
Returns:
Generated audio clip(s) as a torch tensor of shape (1, S) if k=1, else (k, 1, S), where S is the sample length.
Sample rate is 24kHz.
"""
deterministic_seed = deterministic_state(seed=use_deterministic_seed)
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
assert (
text_tokens.shape[-1] < 400
), "Too much text provided. Break the text up into separate segments and re-try inference."
if voice_samples is not None:
(
auto_conditioning,
diffusion_conditioning,
_,
_,
) = self.get_conditioning_latents(
voice_samples,
return_mels=True,
latent_averaging_mode=latent_averaging_mode,
original_tortoise=original_tortoise,
)
elif conditioning_latents is not None:
auto_conditioning, diffusion_conditioning = conditioning_latents
else:
(
auto_conditioning,
diffusion_conditioning,
) = self.get_random_conditioning_latents()
auto_conditioning = auto_conditioning.to(self.device)
diffusion_conditioning = diffusion_conditioning.to(self.device)
diffuser = load_discrete_vocoder_diffuser(
desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k, sampler=sampler
)
# For small sample counts (e.g. the `single_sample` preset), shrink the batch size until it evenly divides num_autoregressive_samples.
orig_batch_size = self.autoregressive_batch_size
while num_autoregressive_samples % self.autoregressive_batch_size:
self.autoregressive_batch_size //= 2
with torch.no_grad():
samples = []
num_batches = num_autoregressive_samples // self.autoregressive_batch_size
stop_mel_token = self.autoregressive.stop_mel_token
calm_token = (
83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
)
self.autoregressive = self.autoregressive.to(self.device)
if verbose:
print("Generating autoregressive samples..")
with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=half
):
for b in tqdm(range(num_batches), disable=not verbose):
codes = autoregressive.inference_speech(
auto_conditioning,
text_tokens,
do_sample=True,
top_p=top_p,
temperature=temperature,
num_return_sequences=self.autoregressive_batch_size,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
max_generate_length=max_mel_tokens,
**hf_generate_kwargs,
)
padding_needed = max_mel_tokens - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
self.autoregressive_batch_size = orig_batch_size # in the case of single_sample
clip_results = []
with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=half
):
for batch in tqdm(samples, disable=not verbose):
for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
clvp_res = clvp(
text_tokens.repeat(batch.shape[0], 1),
batch,
return_loss=False,
)
clip_results.append(clvp_res)
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices]
del samples
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage.
with self.temporary_cuda(self.autoregressive) as autoregressive:
best_latents = autoregressive(
auto_conditioning.repeat(k, 1),
text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
best_results,
torch.tensor(
[best_results.shape[-1] * self.autoregressive.mel_length_compression],
device=text_tokens.device,
),
return_latent=True,
clip_inputs=False,
)
del auto_conditioning
if verbose:
print("Transforming autoregressive outputs into audio..")
wav_candidates = []
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)
latents = best_latents[b].unsqueeze(0)
# Find the first sustained run of "calm" tokens and trim the latents there.
ctokens = 0
for code in range(codes.shape[-1]):
if codes[0, code] == calm_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
latents = latents[:, :code]
break
with self.temporary_cuda(self.diffusion) as diffusion:
mel = do_spectrogram_diffusion(
diffusion,
diffuser,
latents,
diffusion_conditioning,
temperature=diffusion_temperature,
verbose=verbose,
)
with self.temporary_cuda(self.vocoder) as vocoder:
wav = vocoder.inference(mel)
wav_candidates.append(wav.cpu())
def potentially_redact(clip, text):
if self.enable_redaction:
return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
return clip
wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
if len(wav_candidates) > 1:
res = wav_candidates
else:
res = wav_candidates[0]
return_dict = {
"wav": res,
"deterministic_seed": None,
"text": None,
"voice_samples": None,
"conditioning_latents": None,
}
if return_deterministic_state:
return_dict = {
"wav": res,
"deterministic_seed": deterministic_seed,
"text": text,
"voice_samples": voice_samples,
"conditioning_latents": conditioning_latents,
}
return return_dict
def forward(self):
raise NotImplementedError("Tortoise Training is not implemented")
def eval_step(self):
raise NotImplementedError("Tortoise Training is not implemented")
@staticmethod
def init_from_config(config: "TortoiseConfig", **kwargs): # pylint: disable=unused-argument
return Tortoise(config)
def load_checkpoint(
self,
config,
checkpoint_dir,
ar_checkpoint_path=None,
diff_checkpoint_path=None,
clvp_checkpoint_path=None,
vocoder_checkpoint_path=None,
eval=False,
strict=True,
**kwargs,
): # pylint: disable=unused-argument, redefined-builtin
"""Load a model checkpoints from a directory. This model is with multiple checkpoint files and it
expects to have all the files to be under the given `checkpoint_dir` with the rigth names.
If eval is True, set the model to eval mode.
Args:
config (TortoiseConfig): The model config.
checkpoint_dir (str): The directory where the checkpoints are stored.
ar_checkpoint_path (str, optional): The path to the autoregressive checkpoint. Defaults to None.
diff_checkpoint_path (str, optional): The path to the diffusion checkpoint. Defaults to None.
clvp_checkpoint_path (str, optional): The path to the CLVP checkpoint. Defaults to None.
vocoder_checkpoint_path (str, optional): The path to the vocoder checkpoint. Defaults to None.
eval (bool, optional): Whether to set the model to eval mode. Defaults to False.
strict (bool, optional): Whether to load the model strictly. Defaults to True.
"""
if self.models_dir is None:
self.models_dir = checkpoint_dir
ar_path = ar_checkpoint_path or os.path.join(checkpoint_dir, "autoregressive.pth")
diff_path = diff_checkpoint_path or os.path.join(checkpoint_dir, "diffusion_decoder.pth")
clvp_path = clvp_checkpoint_path or os.path.join(checkpoint_dir, "clvp2.pth")
vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth")
if os.path.exists(ar_path):
self.autoregressive.load_state_dict(torch.load(ar_path), strict=strict)
if os.path.exists(diff_path):
self.diffusion.load_state_dict(torch.load(diff_path), strict=strict)
if os.path.exists(clvp_path):
self.clvp.load_state_dict(torch.load(clvp_path), strict=strict)
if os.path.exists(vocoder_checkpoint_path):
self.vocoder.load_state_dict(
config.model_args.vocoder.value.optionally_index(
torch.load(
vocoder_checkpoint_path,
map_location=torch.device("cpu"),
)
)
)
if eval:
self.autoregressive.post_init_gpt2_config(self.args.kv_cache)
self.autoregressive.eval()
self.diffusion.eval()
self.clvp.eval()
self.vocoder.eval()
def train_step(self):
raise NotImplementedError("Tortoise Training is not implemented")

View File

@ -0,0 +1 @@
{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s 
u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}

View File

@ -78,6 +78,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
power=None,
use_htk=False,
mel_norm="slaney",
normalized=False,
):
super().__init__()
self.n_fft = n_fft
@ -96,6 +97,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
self.mel_norm = mel_norm
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
self.normalized = normalized
if use_mel:
self._build_mel_basis()
@ -125,7 +127,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
self.window,
center=True,
pad_mode="reflect", # compatible with audio.py
normalized=False,
normalized=self.normalized,
onesided=True,
return_complex=False,
)

View File

@ -272,10 +272,16 @@ class ModelManager(object):
os.makedirs(output_path, exist_ok=True)
print(f" > Downloading model to {output_path}")
# download from github release
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
self.print_model_license(model_item=model_item)
if isinstance(model_item["github_rls_url"], list):
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
else:
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
self.print_model_license(model_item=model_item)
# find downloaded files
output_model_path, output_config_path = self._find_files(output_path)
output_model_path = output_path
output_config_path = None
if model != "tortoise-v2":
output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json
self._update_paths(output_path, output_config_path)
return output_model_path, output_config_path, model_item
@ -415,6 +421,25 @@ class ModelManager(object):
# remove the extracted folder
rmtree(os.path.join(output_folder, z.namelist()[0]))
@staticmethod
def _download_model_files(file_urls, output_folder, progress_bar):
"""Download the github releases"""
for file_url in file_urls:
# download the file
r = requests.get(file_url, stream=True)
# write the file to the output folder under its original name
base_filename = file_url.split("/")[-1]
file_path = os.path.join(output_folder, base_filename)
total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
with open(file_path, "wb") as file:
if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
for data in r.iter_content(block_size):
if progress_bar:
progress_bar.update(len(data))
file.write(data)
@staticmethod
def _check_dict_key(my_dict, key):
if key in my_dict.keys() and my_dict[key] is not None:

View File

@ -1,3 +1,4 @@
import os
import time
from typing import List
@ -31,6 +32,8 @@ class Synthesizer(object):
encoder_config: str = "",
vc_checkpoint: str = "",
vc_config: str = "",
model_dir: str = "",
voice_dir: str = None,
use_cuda: bool = False,
) -> None:
"""General 🐸 TTS interface for inference. It takes a tts and a vocoder
@ -78,7 +81,7 @@ class Synthesizer(object):
self.d_vector_dim = 0
self.seg = self._get_segmenter("en")
self.use_cuda = use_cuda
self.voice_dir = voice_dir
if self.use_cuda:
assert torch.cuda.is_available(), "CUDA is not available on this machine."
@ -94,6 +97,10 @@ class Synthesizer(object):
self._load_vc(vc_checkpoint, vc_config, use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if model_dir:
self._load_tts_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
@staticmethod
def _get_segmenter(lang: str):
"""get the sentence segmenter for the given language.
@ -126,6 +133,19 @@ class Synthesizer(object):
if use_cuda:
self.vc_model.cuda()
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
"""Load the TTS model from a directory.
We assume the model knows how to load itself from the directory and that a config.json file exists there.
"""
config = load_config(os.path.join(model_dir, "config.json"))
self.tts_config = config
self.tts_model = setup_tts_model(config)
self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
if use_cuda:
self.tts_model.cuda()
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
"""Load the TTS model.
@ -220,6 +240,7 @@ class Synthesizer(object):
style_text=None,
reference_wav=None,
reference_speaker_name=None,
**kwargs,
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
@ -249,6 +270,9 @@ class Synthesizer(object):
print(sens)
# handle multi-speaker
if "voice_dir" in kwargs:
self.voice_dir = kwargs["voice_dir"]
kwargs.pop("voice_dir")
speaker_embedding = None
speaker_id = None
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
@ -275,7 +299,7 @@ class Synthesizer(object):
else:
speaker_embedding = None
else:
if speaker_name:
if speaker_name and self.voice_dir is None:
raise ValueError(
f" [!] Missing speakers.json file path for selecting speaker {speaker_name}."
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
@ -312,29 +336,39 @@ class Synthesizer(object):
)
# compute a new d_vector from the given clip.
if speaker_wav is not None:
if speaker_wav is not None and self.tts_model.speaker_manager is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
use_gl = self.vocoder_model is None
if not reference_wav:
for sen in sens:
# synthesize voice
outputs = synthesis(
model=self.tts_model,
text=sen,
CONFIG=self.tts_config,
use_cuda=self.use_cuda,
speaker_id=speaker_id,
style_wav=style_wav,
style_text=style_text,
use_griffin_lim=use_gl,
d_vector=speaker_embedding,
language_id=language_id,
)
if hasattr(self.tts_model, "synthesize"):
sp_name = "random" if speaker_name is None else speaker_name
outputs = self.tts_model.synthesize(
text=sen,
config=self.tts_config,
speaker_id=sp_name,
extra_voice_dirs=self.voice_dir,
**kwargs,
)
else:
# synthesize voice
outputs = synthesis(
model=self.tts_model,
text=sen,
CONFIG=self.tts_config,
use_cuda=self.use_cuda,
speaker_id=speaker_id,
style_wav=style_wav,
style_text=style_text,
use_griffin_lim=use_gl,
d_vector=speaker_embedding,
language_id=language_id,
)
waveform = outputs["wav"]
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
if not use_gl:
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
# denormalize tts output based on tts audio config
mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
device_type = "cuda" if self.use_cuda else "cpu"

View File

@ -51,6 +51,7 @@
models/forward_tts.md
models/tacotron1-2.md
models/overflow.md
models/tortoise.md
.. toctree::
:maxdepth: 2

View File

@ -0,0 +1,94 @@
# Tortoise 🐢
Tortoise is a very expressive TTS system with impressive voice cloning capabilities. It is based on a GPT-like autoregressive acoustic model that converts input
text to discretized acoustic tokens, a diffusion model that converts these tokens to mel-spectrogram frames, and a UnivNet vocoder that converts the spectrograms to
the final audio signal. The main downside is that Tortoise is very slow compared to parallel TTS models like VITS.
Big thanks to 👑[@manmay-nakhashi](https://github.com/manmay-nakhashi) who helped us implement Tortoise in 🐸TTS.
Example use:
```python
from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.models.tortoise import Tortoise
config = TortoiseConfig()
model = Tortoise.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True)
# with random speaker
output_dict = model.synthesize(text, config, speaker_id="random", extra_voice_dirs=None, **kwargs)
# cloning a speaker
output_dict = model.synthesize(text, config, speaker_id="speaker_n", extra_voice_dirs="path/to/speaker_n/", **kwargs)
```
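When synthesizing many sentences with the same voice, the conditioning latents can be computed once and reused. A minimal sketch (not part of the original docs; it reuses `model` from the example above, and the voice folder and inference settings are placeholders):

```python
from TTS.tts.layers.tortoise.audio_utils import load_voice

# Load the reference clips for the `lj` voice shipped with 🐸TTS.
voice_samples, _ = load_voice("lj", ["TTS/tts/utils/assets/tortoise/voices/"])

# Compute the autoregressive and diffusion conditioning latents once.
conditioning_latents = model.get_conditioning_latents(voice_samples)

# Reuse the latents for every sentence instead of re-encoding the clips.
for sentence in ["First sentence.", "Second sentence."]:
    out = model.inference(
        sentence,
        conditioning_latents=conditioning_latents,
        num_autoregressive_samples=16,
        diffusion_iterations=30,
    )
    wav = out["wav"]  # 24 kHz waveform tensor
```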
Using 🐸TTS API:
```python
from TTS.api import TTS
tts = TTS("tts_models/en/multi-dataset/tortoise-v2")
# cloning `lj` voice from `TTS/tts/utils/assets/tortoise/voices/lj`
# with custom inference settings overriding defaults.
tts.tts_to_file(text="Hello, my name is Manmay, how are you?",
file_path="output.wav",
voice_dir="TTS/tts/utils/assets/tortoise/voices/",
speaker="lj",
num_autoregressive_samples=1,
diffusion_iterations=10)
# Using presets with the same voice
tts.tts_to_file(text="Hello, my name is Manmay, how are you?",
file_path="output.wav",
voice_dir="TTS/tts/utils/assets/tortoise/voices/",
speaker="lj",
preset="ultra_fast")
# Random voice generation
tts.tts_to_file(text="Hello, my name is Manmay, how are you?",
file_path="output.wav")
```
Using 🐸TTS Command line:
```console
# cloning the `lj` voice
tts --model_name tts_models/en/multi-dataset/tortoise-v2 \
--text "This is an example." \
--out_path "/data/speech_synth/coqui-tts/TTS/tests/outputs/output.wav" \
--voice_dir TTS/tts/utils/assets/tortoise/voices/ \
--speaker_idx "lj" \
--progress_bar True
# Random voice generation
tts --model_name tts_models/en/multi-dataset/tortoise-v2 \
--text "This is an example." \
--out_path "/data/speech_synth/coqui-tts/TTS/tests/outputs/output.wav" \
--progress_bar True
```
## Important resources & papers
- Original Repo: https://github.com/neonbjb/tortoise-tts
- Faster implementation: https://github.com/152334H/tortoise-tts-fast
- Univnet: https://arxiv.org/abs/2106.07889
- Latent Diffusion: https://arxiv.org/abs/2112.10752
- DALL-E: https://arxiv.org/abs/2102.12092
## TortoiseConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.tortoise_config.TortoiseConfig
:members:
```
## TortoiseArgs
```{eval-rst}
.. autoclass:: TTS.tts.models.tortoise.TortoiseArgs
:members:
```
## Tortoise Model
```{eval-rst}
.. autoclass:: TTS.tts.models.tortoise.Tortoise
:members:
```

108
notebooks/Tortoise.ipynb Normal file
View File

@ -0,0 +1,108 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "4d50310e-f094-42e0-af30-1e42b13ceb95",
"metadata": {},
"outputs": [],
"source": [
"#@title # Setup\n",
"# Imports used through the rest of the notebook.\n",
"import torch\n",
"import torchaudio\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import IPython\n",
"\n",
"from TTS.tts.models.tortoise import TextToSpeech\n",
"from TTS.tts.layers.tortoise.audio_utils import load_audio, load_voice, load_voices\n",
"\n",
"# This will download all the models used by Tortoise from the HuggingFace hub.\n",
"tts = TextToSpeech()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e126c3c3-d90a-492f-b5bb-0d86587f15cc",
"metadata": {},
"outputs": [],
"source": [
"# This is the text that will be spoken.\n",
"text = \"Joining two modalities results in a surprising increase in generalization! What would happen if we combined them all?\" #@param {type:\"string\"}\n",
"#@markdown Show code for multiline text input\n",
"# Here's something for the poetically inclined.. (set text=)\n",
"\"\"\"\n",
"Then took the other, as just as fair,\n",
"And having perhaps the better claim,\n",
"Because it was grassy and wanted wear;\n",
"Though as for that the passing there\n",
"Had worn them really about the same,\"\"\"\n",
"\n",
"# Pick a \"preset mode\" to determine quality. Options: {\"ultra_fast\", \"fast\" (default), \"standard\", \"high_quality\"}. See docs in api.py\n",
"preset = \"fast\" #@param [\"ultra_fast\", \"fast\", \"standard\", \"high_quality\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9413f553-5bd0-4820-bad4-edd7fd7d2370",
"metadata": {},
"outputs": [],
"source": [
"%ls ../TTS/tts/utils/assets/tortoise/voices/\n",
"import IPython\n",
"IPython.display.Audio(filename='../TTS/tts/utils/assets/tortoise/voices/lj/1.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96a98ae5-313b-40d1-9311-5a785f2c9a4e",
"metadata": {},
"outputs": [],
"source": [
"#@markdown Pick one of the voices from the output above\n",
"voice = 'lj' #@param {type:\"string\"}\n",
"\n",
"#@markdown Load it and send it through Tortoise.\n",
"voice_samples, conditioning_latents = load_voice(voice)\n",
"gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents, \n",
" preset=preset)\n",
"torchaudio.save('generated.wav', gen.squeeze(0).cpu(), 24000)\n",
"IPython.display.Audio('generated.wav')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04e473e5-c489-4a78-aa11-03e89a778ed8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -45,3 +45,8 @@ g2pkk>=0.1.1
bangla==0.0.2
bnnumerizer
bnunicodenormalizer==0.1.1
#deps for tortoise
k_diffusion
einops
transformers

View File

@ -12,6 +12,7 @@ is_coqui_available = os.environ.get("COQUI_STUDIO_TOKEN")
if is_coqui_available:
class CS_APITest(unittest.TestCase):
def test_speakers(self):
tts = CS_API()
@ -40,7 +41,6 @@ if is_coqui_available:
self.assertEqual(sr, 44100)
self.assertGreater(len(wav), 1)
class TTSTest(unittest.TestCase):
def test_single_speaker_model(self):
tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
@ -86,7 +86,9 @@ if is_coqui_available:
def test_multi_speaker_multi_lingual_model(self):
tts = TTS()
tts.load_tts_model_by_name(tts.models[0]) # YourTTS
tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path=OUTPUT_PATH)
tts.tts_to_file(
text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path=OUTPUT_PATH
)
self.assertTrue(tts.is_multi_speaker)
self.assertTrue(tts.is_multi_lingual)