mirror of https://github.com/coqui-ai/TTS.git
XTTS v2.0 (#3137)
* Implement most similar ref training approach * Use non-enhanced hifigan for test samples * Add Perceiver * Update GPT Trainer for perceiver support * Update XTTS docs * Bug fix masking with XTTS perceiver * Bug fix on gpt forward * Bug Fix on XTTS v2.0 training * Add XTTS v2.0 unit tests * Add XTTS v2.0 inference unit tests * Bug Fix on diffusion inference * Add XTTS v2.0 training recipe * Placeholder model entry * Add cloning params to config * Make prompt embedding configurable * Make cloning configurable * Cheap fix for a cheaper fix * Prevent resampling * Update model entry * Update docs * Update requirements * Code linting * Add xtts v2 to sep tests * Bug fix on XTTS get_gpt_cond_latents * Bug fix on rebase * Make style * Bug fix in Japenese tokenizer * Add num2words to deps * Remove unused kwarg and added num_beams=1 as default --------- Co-authored-by: Eren G??lge <egolge@coqui.ai>pull/3120/head
parent
38f6f8f0bb
commit
e45227d9ff
|
@ -169,3 +169,4 @@ wandb
|
|||
depot/*
|
||||
coqui_recipes/*
|
||||
local_scripts/*
|
||||
coqui_demos/*
|
|
@ -2,6 +2,20 @@
|
|||
"tts_models": {
|
||||
"multilingual": {
|
||||
"multi-dataset": {
|
||||
"xtts_v2": {
|
||||
"description": "XTTS-v2 by Coqui with 16 languages.",
|
||||
"hf_url": [
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
|
||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
|
||||
],
|
||||
"default_vocoder": null,
|
||||
"commit": "480a6cdf7",
|
||||
"license": "CPML",
|
||||
"contact": "info@coqui.ai",
|
||||
"tos_required": true
|
||||
},
|
||||
"xtts_v1": {
|
||||
"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
|
||||
"hf_url": [
|
||||
|
|
|
@ -59,6 +59,16 @@ class XttsConfig(BaseTTSConfig):
|
|||
|
||||
decoder_sampler (str):
|
||||
Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`.
|
||||
|
||||
gpt_cond_len (int):
|
||||
Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
|
||||
|
||||
max_ref_len (int):
|
||||
Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.
|
||||
|
||||
sound_norm_refs (bool):
|
||||
Whether to normalize the conditioning audio. Defaults to `False`.
|
||||
|
||||
Note:
|
||||
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
|
||||
|
||||
|
@ -74,7 +84,24 @@ class XttsConfig(BaseTTSConfig):
|
|||
audio: XttsAudioConfig = field(default_factory=XttsAudioConfig)
|
||||
model_dir: str = None
|
||||
languages: List[str] = field(
|
||||
default_factory=lambda: ["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"]
|
||||
default_factory=lambda: [
|
||||
"en",
|
||||
"es",
|
||||
"fr",
|
||||
"de",
|
||||
"it",
|
||||
"pt",
|
||||
"pl",
|
||||
"tr",
|
||||
"ru",
|
||||
"nl",
|
||||
"cs",
|
||||
"ar",
|
||||
"zh-cn",
|
||||
"hu",
|
||||
"ko",
|
||||
"ja",
|
||||
]
|
||||
)
|
||||
|
||||
# inference params
|
||||
|
@ -88,3 +115,8 @@ class XttsConfig(BaseTTSConfig):
|
|||
num_gpt_outputs: int = 1
|
||||
decoder_iterations: int = 30
|
||||
decoder_sampler: str = "ddim"
|
||||
|
||||
# cloning
|
||||
gpt_cond_len: int = 3
|
||||
max_ref_len: int = 10
|
||||
sound_norm_refs: bool = False
|
||||
|
|
|
@ -562,21 +562,15 @@ class DPM_Solver:
|
|||
if order == 3:
|
||||
K = steps // 3 + 1
|
||||
if steps % 3 == 0:
|
||||
orders = [
|
||||
3,
|
||||
] * (
|
||||
orders = [3,] * (
|
||||
K - 2
|
||||
) + [2, 1]
|
||||
elif steps % 3 == 1:
|
||||
orders = [
|
||||
3,
|
||||
] * (
|
||||
orders = [3,] * (
|
||||
K - 1
|
||||
) + [1]
|
||||
else:
|
||||
orders = [
|
||||
3,
|
||||
] * (
|
||||
orders = [3,] * (
|
||||
K - 1
|
||||
) + [2]
|
||||
elif order == 2:
|
||||
|
@ -587,9 +581,7 @@ class DPM_Solver:
|
|||
] * K
|
||||
else:
|
||||
K = steps // 2 + 1
|
||||
orders = [
|
||||
2,
|
||||
] * (
|
||||
orders = [2,] * (
|
||||
K - 1
|
||||
) + [1]
|
||||
elif order == 1:
|
||||
|
@ -1448,10 +1440,7 @@ class DPM_Solver:
|
|||
model_prev_list[-1] = self.model_fn(x, t)
|
||||
elif method in ["singlestep", "singlestep_fixed"]:
|
||||
if method == "singlestep":
|
||||
(
|
||||
timesteps_outer,
|
||||
orders,
|
||||
) = self.get_orders_and_timesteps_for_singlestep_solver(
|
||||
(timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
|
||||
steps=steps,
|
||||
order=order,
|
||||
skip_type=skip_type,
|
||||
|
|
|
@ -11,6 +11,7 @@ from transformers import GPT2Config
|
|||
|
||||
from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
|
||||
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
|
||||
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
|
||||
|
||||
|
||||
def null_position_embeddings(range, dim):
|
||||
|
@ -105,6 +106,8 @@ class GPT(nn.Module):
|
|||
checkpointing=False,
|
||||
average_conditioning_embeddings=False,
|
||||
label_smoothing=0.0,
|
||||
use_perceiver_resampler=False,
|
||||
perceiver_cond_length_compression=256,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -132,13 +135,12 @@ class GPT(nn.Module):
|
|||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||
self.conditioning_dropout = nn.Dropout1d(0.1)
|
||||
self.average_conditioning_embeddings = average_conditioning_embeddings
|
||||
self.use_perceiver_resampler = use_perceiver_resampler
|
||||
self.perceiver_cond_length_compression = perceiver_cond_length_compression
|
||||
|
||||
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
||||
self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
||||
|
||||
self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
||||
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
|
||||
|
||||
(
|
||||
self.gpt,
|
||||
self.mel_pos_embedding,
|
||||
|
@ -165,9 +167,29 @@ class GPT(nn.Module):
|
|||
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
||||
self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)
|
||||
|
||||
if self.use_perceiver_resampler:
|
||||
# XTTS v2
|
||||
self.conditioning_perceiver = PerceiverResampler(
|
||||
dim=model_dim,
|
||||
depth=2,
|
||||
dim_context=model_dim,
|
||||
num_latents=32,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
ff_mult=4,
|
||||
use_flash_attn=False,
|
||||
)
|
||||
else:
|
||||
# XTTS v1
|
||||
self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
||||
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
return {
|
||||
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
|
||||
"conditioning_perceiver": list(self.conditioning_perceiver.parameters())
|
||||
if self.use_perceiver_resampler
|
||||
else None,
|
||||
"gpt": list(self.gpt.parameters()),
|
||||
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
|
||||
}
|
||||
|
@ -250,11 +272,8 @@ class GPT(nn.Module):
|
|||
if attn_mask_text is not None:
|
||||
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
|
||||
if prompt is not None:
|
||||
if attn_mask_cond is not None:
|
||||
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
|
||||
else:
|
||||
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
|
||||
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
|
||||
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
|
||||
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
|
||||
|
||||
gpt_out = self.gpt(
|
||||
inputs_embeds=emb,
|
||||
|
@ -318,7 +337,6 @@ class GPT(nn.Module):
|
|||
prompt_len = 3
|
||||
prompt_len = prompt_len * 24 # in frames
|
||||
if prompt_codes.shape[-1] >= prompt_len:
|
||||
new_prompt = []
|
||||
for i in range(prompt_codes.shape[0]):
|
||||
if lengths[i] < prompt_len:
|
||||
start = 0
|
||||
|
@ -340,7 +358,9 @@ class GPT(nn.Module):
|
|||
if not return_latent:
|
||||
if cond_input.ndim == 4:
|
||||
cond_input = cond_input.squeeze(1)
|
||||
conds = self.conditioning_encoder(cond_input)
|
||||
conds = self.conditioning_encoder(cond_input) # (b, d, s)
|
||||
if self.use_perceiver_resampler:
|
||||
conds = self.conditioning_perceiver(conds.permute(0, 2, 1)).transpose(1, 2) # (b, d, 32)
|
||||
else:
|
||||
# already computed
|
||||
conds = cond_input.unsqueeze(1)
|
||||
|
@ -354,6 +374,7 @@ class GPT(nn.Module):
|
|||
wav_lengths,
|
||||
cond_mels=None,
|
||||
cond_idxs=None,
|
||||
cond_lens=None,
|
||||
cond_latents=None,
|
||||
return_attentions=False,
|
||||
return_latent=False,
|
||||
|
@ -379,10 +400,24 @@ class GPT(nn.Module):
|
|||
max_text_len = text_lengths.max()
|
||||
code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3
|
||||
|
||||
if cond_lens is not None:
|
||||
if self.use_perceiver_resampler:
|
||||
cond_lens = cond_lens // self.perceiver_cond_length_compression
|
||||
else:
|
||||
cond_lens = cond_lens // self.code_stride_len
|
||||
|
||||
if cond_idxs is not None:
|
||||
# recompute cond idxs for mel lengths
|
||||
for idx, l in enumerate(code_lengths):
|
||||
cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
|
||||
for idx in range(cond_idxs.size(0)):
|
||||
if self.use_perceiver_resampler:
|
||||
cond_idxs[idx] = cond_idxs[idx] // self.perceiver_cond_length_compression
|
||||
else:
|
||||
cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len
|
||||
|
||||
# ensure that the cond_mel does not have padding
|
||||
# if cond_lens is not None and cond_idxs is None:
|
||||
# min_cond_len = torch.min(cond_lens)
|
||||
# cond_mels = cond_mels[:, :, :, :min_cond_len]
|
||||
|
||||
# If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
|
||||
max_mel_len = code_lengths.max()
|
||||
|
@ -450,9 +485,13 @@ class GPT(nn.Module):
|
|||
)
|
||||
|
||||
if cond_idxs is not None:
|
||||
# use masking approach
|
||||
for idx, r in enumerate(cond_idxs):
|
||||
l = r[1] - r[0]
|
||||
attn_mask_cond[idx, l:] = 0.0
|
||||
elif cond_lens is not None:
|
||||
for idx, l in enumerate(cond_lens):
|
||||
attn_mask_cond[idx, l:] = 0.0
|
||||
|
||||
for idx, l in enumerate(text_lengths):
|
||||
attn_mask_text[idx, l + 1 :] = 0.0
|
||||
|
@ -523,7 +562,7 @@ class GPT(nn.Module):
|
|||
|
||||
def inference(self, cond_latents, text_inputs, **hf_generate_kwargs):
|
||||
self.compute_embeddings(cond_latents, text_inputs)
|
||||
return self.generate(cond_latents, text_inputs, input_tokens=None, **hf_generate_kwargs)
|
||||
return self.generate(cond_latents, text_inputs, **hf_generate_kwargs)
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
|
|
|
@ -0,0 +1,319 @@
|
|||
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
||||
|
||||
from collections import namedtuple
|
||||
from functools import wraps
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
from packaging import version
|
||||
from torch import einsum, nn
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def once(fn):
|
||||
called = False
|
||||
|
||||
@wraps(fn)
|
||||
def inner(x):
|
||||
nonlocal called
|
||||
if called:
|
||||
return
|
||||
called = True
|
||||
return fn(x)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
print_once = once(print)
|
||||
|
||||
# main class
|
||||
|
||||
|
||||
class Attend(nn.Module):
|
||||
def __init__(self, dropout=0.0, causal=False, use_flash=False):
|
||||
super().__init__()
|
||||
self.dropout = dropout
|
||||
self.attn_dropout = nn.Dropout(dropout)
|
||||
|
||||
self.causal = causal
|
||||
self.register_buffer("mask", None, persistent=False)
|
||||
|
||||
self.use_flash = use_flash
|
||||
assert not (
|
||||
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
||||
), "in order to use flash attention, you must be using pytorch 2.0 or above"
|
||||
|
||||
# determine efficient attention configs for cuda and cpu
|
||||
self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
|
||||
self.cpu_config = self.config(True, True, True)
|
||||
self.cuda_config = None
|
||||
|
||||
if not torch.cuda.is_available() or not use_flash:
|
||||
return
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
|
||||
|
||||
if device_properties.major == 8 and device_properties.minor == 0:
|
||||
print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
|
||||
self.cuda_config = self.config(True, False, False)
|
||||
else:
|
||||
print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda")
|
||||
self.cuda_config = self.config(False, True, True)
|
||||
|
||||
def get_mask(self, n, device):
|
||||
if exists(self.mask) and self.mask.shape[-1] >= n:
|
||||
return self.mask[:n, :n]
|
||||
|
||||
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
return mask
|
||||
|
||||
def flash_attn(self, q, k, v, mask=None):
|
||||
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
|
||||
|
||||
# Recommended for multi-query single-key-value attention by Tri Dao
|
||||
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
|
||||
|
||||
if k.ndim == 3:
|
||||
k = rearrange(k, "b ... -> b 1 ...").expand_as(q)
|
||||
|
||||
if v.ndim == 3:
|
||||
v = rearrange(v, "b ... -> b 1 ...").expand_as(q)
|
||||
|
||||
# Check if mask exists and expand to compatible shape
|
||||
# The mask is B L, so it would have to be expanded to B H N L
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, "b j -> b 1 1 j")
|
||||
mask = mask.expand(-1, heads, q_len, -1)
|
||||
|
||||
# Check if there is a compatible device for flash attention
|
||||
|
||||
config = self.cuda_config if is_cuda else self.cpu_config
|
||||
|
||||
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
|
||||
|
||||
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
h - heads
|
||||
n, i, j - sequence length (base sequence length, source, target)
|
||||
d - feature dimension
|
||||
"""
|
||||
|
||||
n, device = q.shape[-2], q.device
|
||||
|
||||
scale = q.shape[-1] ** -0.5
|
||||
|
||||
if self.use_flash:
|
||||
return self.flash_attn(q, k, v, mask=mask)
|
||||
|
||||
kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"
|
||||
|
||||
# similarity
|
||||
|
||||
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
|
||||
|
||||
# key padding mask
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, "b j -> b 1 1 j")
|
||||
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
|
||||
|
||||
# causal mask
|
||||
|
||||
if self.causal:
|
||||
causal_mask = self.get_mask(n, device)
|
||||
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
|
||||
|
||||
# attention
|
||||
|
||||
attn = sim.softmax(dim=-1)
|
||||
attn = self.attn_dropout(attn)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def Sequential(*mods):
|
||||
return nn.Sequential(*filter(exists, mods))
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if callable(d) else d
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, scale=True, dim_cond=None):
|
||||
super().__init__()
|
||||
self.cond = exists(dim_cond)
|
||||
self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
|
||||
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
gamma = default(self.gamma, 1)
|
||||
out = F.normalize(x, dim=-1) * self.scale * gamma
|
||||
|
||||
if not self.cond:
|
||||
return out
|
||||
|
||||
assert exists(cond)
|
||||
gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
|
||||
gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
|
||||
return out * gamma + beta
|
||||
|
||||
|
||||
class CausalConv1d(nn.Conv1d):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
(kernel_size,) = self.kernel_size
|
||||
(dilation,) = self.dilation
|
||||
(stride,) = self.stride
|
||||
|
||||
assert stride == 1
|
||||
self.causal_padding = dilation * (kernel_size - 1)
|
||||
|
||||
def forward(self, x):
|
||||
causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||
return super().forward(causal_padded_x)
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
def forward(self, x):
|
||||
x, gate = x.chunk(2, dim=-1)
|
||||
return F.gelu(gate) * x
|
||||
|
||||
|
||||
def FeedForward(dim, mult=4, causal_conv=False):
|
||||
dim_inner = int(dim * mult * 2 / 3)
|
||||
|
||||
conv = None
|
||||
if causal_conv:
|
||||
conv = nn.Sequential(
|
||||
Rearrange("b n d -> b d n"),
|
||||
CausalConv1d(dim_inner, dim_inner, 3),
|
||||
Rearrange("b d n -> b n d"),
|
||||
)
|
||||
|
||||
return Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim))
|
||||
|
||||
|
||||
class PerceiverResampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
depth=2,
|
||||
dim_context=None,
|
||||
num_latents=32,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
ff_mult=4,
|
||||
use_flash_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
dim_context = default(dim_context, dim)
|
||||
|
||||
self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
|
||||
|
||||
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
||||
nn.init.normal_(self.latents, std=0.02)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
Attention(
|
||||
dim=dim,
|
||||
dim_head=dim_head,
|
||||
heads=heads,
|
||||
use_flash=use_flash_attn,
|
||||
cross_attn_include_queries=True,
|
||||
),
|
||||
FeedForward(dim=dim, mult=ff_mult),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(dim)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
batch = x.shape[0]
|
||||
|
||||
x = self.proj_context(x)
|
||||
|
||||
latents = repeat(self.latents, "n d -> b n d", b=batch)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
latents = attn(latents, x, mask=mask) + latents
|
||||
latents = ff(latents) + latents
|
||||
|
||||
return self.norm(latents)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
dim_context=None,
|
||||
causal=False,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
dropout=0.0,
|
||||
use_flash=False,
|
||||
cross_attn_include_queries=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
self.cross_attn_include_queries = cross_attn_include_queries
|
||||
|
||||
dim_inner = dim_head * heads
|
||||
dim_context = default(dim_context, dim)
|
||||
|
||||
self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
|
||||
self.to_q = nn.Linear(dim, dim_inner, bias=False)
|
||||
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
|
||||
self.to_out = nn.Linear(dim_inner, dim, bias=False)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h, has_context = self.heads, exists(context)
|
||||
|
||||
context = default(context, x)
|
||||
|
||||
if has_context and self.cross_attn_include_queries:
|
||||
context = torch.cat((x, context), dim=-2)
|
||||
|
||||
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
||||
|
||||
out = self.attend(q, k, v, mask=mask)
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)")
|
||||
return self.to_out(out)
|
|
@ -4,6 +4,8 @@ import re
|
|||
|
||||
import pypinyin
|
||||
import torch
|
||||
from hangul_romanize import Transliter
|
||||
from hangul_romanize.rule import academic
|
||||
from num2words import num2words
|
||||
from tokenizers import Tokenizer
|
||||
|
||||
|
@ -112,7 +114,7 @@ _abbreviations = {
|
|||
# There are not many common abbreviations in Arabic as in English.
|
||||
]
|
||||
],
|
||||
"zh-cn": [
|
||||
"zh": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
|
||||
|
@ -155,6 +157,21 @@ _abbreviations = {
|
|||
# Add other Turkish abbreviations here if needed.
|
||||
]
|
||||
],
|
||||
"hu": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("dr", "doktor"), # doctor
|
||||
("b", "bácsi"), # Mr.
|
||||
("nőv", "nővér"), # nurse
|
||||
# Add other Hungarian abbreviations here if needed.
|
||||
]
|
||||
],
|
||||
"ko": [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
# Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
@ -262,7 +279,7 @@ _symbols_multilingual = {
|
|||
("°", " درجة "),
|
||||
]
|
||||
],
|
||||
"zh-cn": [
|
||||
"zh": [
|
||||
# Chinese
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
|
@ -326,6 +343,31 @@ _symbols_multilingual = {
|
|||
("°", " derece "),
|
||||
]
|
||||
],
|
||||
"hu": [
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " és "),
|
||||
("@", " kukac "),
|
||||
("%", " százalék "),
|
||||
("#", " kettőskereszt "),
|
||||
("$", " dollár "),
|
||||
("£", " font "),
|
||||
("°", " fok "),
|
||||
]
|
||||
],
|
||||
"ko": [
|
||||
# Korean
|
||||
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("&", " 그리고 "),
|
||||
("@", " 에 "),
|
||||
("%", " 퍼센트 "),
|
||||
("#", " 번호 "),
|
||||
("$", " 달러 "),
|
||||
("£", " 파운드 "),
|
||||
("°", " 도 "),
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
@ -349,6 +391,8 @@ _ordinal_re = {
|
|||
"ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
|
||||
"nl": re.compile(r"([0-9]+)(de|ste|e)"),
|
||||
"tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
|
||||
"hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"),
|
||||
"ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
|
||||
}
|
||||
_number_re = re.compile(r"[0-9]+")
|
||||
_currency_re = {
|
||||
|
@ -398,6 +442,8 @@ def _expand_currency(m, lang="en", currency="USD"):
|
|||
"nl": ", ",
|
||||
"ar": ", ",
|
||||
"tr": ", ",
|
||||
"hu": ", ",
|
||||
"ko": ", ",
|
||||
}
|
||||
|
||||
if amount.is_integer():
|
||||
|
@ -417,7 +463,7 @@ def _expand_number(m, lang="en"):
|
|||
|
||||
|
||||
def expand_numbers_multilingual(text, lang="en"):
|
||||
if lang == "zh-cn":
|
||||
if lang == "zh" or lang == "zh-cn":
|
||||
text = zh_num2words()(text)
|
||||
else:
|
||||
if lang in ["en", "ru"]:
|
||||
|
@ -468,7 +514,7 @@ def basic_cleaners(text):
|
|||
|
||||
def chinese_transliterate(text):
|
||||
return "".join(
|
||||
p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)
|
||||
[p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
|
||||
)
|
||||
|
||||
|
||||
|
@ -478,42 +524,23 @@ def japanese_cleaners(text, katsu):
|
|||
return text
|
||||
|
||||
|
||||
def korean_cleaners(text):
|
||||
r = Transliter(academic)
|
||||
return r.translit(text)
|
||||
|
||||
|
||||
DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json")
|
||||
|
||||
|
||||
class VoiceBpeTokenizer:
|
||||
def __init__(self, vocab_file=None, preprocess=None):
|
||||
def __init__(self, vocab_file=None):
|
||||
self.tokenizer = None
|
||||
self.katsu = None
|
||||
|
||||
if vocab_file is not None:
|
||||
with open(vocab_file, "r", encoding="utf-8") as f:
|
||||
vocab = json.load(f)
|
||||
|
||||
self.language = vocab["model"]["language"] if "language" in vocab["model"] else None
|
||||
|
||||
if preprocess is None:
|
||||
self.preprocess = "pre_tokenizer" in vocab and vocab["pre_tokenizer"]
|
||||
else:
|
||||
self.preprocess = preprocess
|
||||
|
||||
self.tokenizer = Tokenizer.from_file(vocab_file)
|
||||
|
||||
def preprocess_text(self, txt, lang):
|
||||
if lang in ["en", "es", "fr", "de", "pt", "it", "pl", "ar", "cs", "ru", "nl", "tr", "zh-cn"]:
|
||||
txt = multilingual_cleaners(txt, lang)
|
||||
if lang == "zh-cn":
|
||||
txt = chinese_transliterate(txt)
|
||||
elif lang == "ja":
|
||||
if self.katsu is None:
|
||||
import cutlet
|
||||
|
||||
self.katsu = cutlet.Cutlet()
|
||||
txt = japanese_cleaners(txt, self.katsu)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return txt
|
||||
|
||||
def encode(self, txt, lang):
|
||||
if self.preprocess:
|
||||
txt = self.preprocess_text(txt, lang)
|
||||
txt = self.preprocess_text(txt, lang)
|
||||
txt = f"[{lang}]{txt}"
|
||||
txt = txt.replace(" ", "[SPACE]")
|
||||
return self.tokenizer.encode(txt).ids
|
||||
|
@ -527,8 +554,200 @@ class VoiceBpeTokenizer:
|
|||
txt = txt.replace("[UNK]", "")
|
||||
return txt
|
||||
|
||||
def preprocess_text(self, txt, lang):
|
||||
if lang in ["en", "es", "fr", "de", "pt", "it", "pl", "zh", "ar", "cs", "ru", "nl", "tr", "hu"]:
|
||||
txt = multilingual_cleaners(txt, lang)
|
||||
elif lang == "ja":
|
||||
if self.katsu is None:
|
||||
import cutlet
|
||||
|
||||
self.katsu = cutlet.Cutlet()
|
||||
txt = japanese_cleaners(txt, self.katsu)
|
||||
elif lang == "zh-cn" or lang == "zh":
|
||||
txt = chinese_transliterate(txt)
|
||||
elif lang == "ko":
|
||||
txt = korean_cleaners(txt)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return txt
|
||||
|
||||
def __len__(self):
|
||||
return self.tokenizer.get_vocab_size()
|
||||
|
||||
def get_number_tokens(self):
|
||||
return max(self.tokenizer.get_vocab().values()) + 1
|
||||
|
||||
|
||||
def test_expand_numbers_multilingual():
|
||||
test_cases = [
|
||||
# English
|
||||
("In 12.5 seconds.", "In twelve point five seconds.", "en"),
|
||||
("There were 50 soldiers.", "There were fifty soldiers.", "en"),
|
||||
("This is a 1st test", "This is a first test", "en"),
|
||||
("That will be $20 sir.", "That will be twenty dollars sir.", "en"),
|
||||
("That will be 20€ sir.", "That will be twenty euro sir.", "en"),
|
||||
("That will be 20.15€ sir.", "That will be twenty euro, fifteen cents sir.", "en"),
|
||||
("That's 100,000.5.", "That's one hundred thousand point five.", "en"),
|
||||
# French
|
||||
("En 12,5 secondes.", "En douze virgule cinq secondes.", "fr"),
|
||||
("Il y avait 50 soldats.", "Il y avait cinquante soldats.", "fr"),
|
||||
("Ceci est un 1er test", "Ceci est un premier test", "fr"),
|
||||
("Cela vous fera $20 monsieur.", "Cela vous fera vingt dollars monsieur.", "fr"),
|
||||
("Cela vous fera 20€ monsieur.", "Cela vous fera vingt euros monsieur.", "fr"),
|
||||
("Cela vous fera 20,15€ monsieur.", "Cela vous fera vingt euros et quinze centimes monsieur.", "fr"),
|
||||
("Ce sera 100.000,5.", "Ce sera cent mille virgule cinq.", "fr"),
|
||||
# German
|
||||
("In 12,5 Sekunden.", "In zwölf Komma fünf Sekunden.", "de"),
|
||||
("Es gab 50 Soldaten.", "Es gab fünfzig Soldaten.", "de"),
|
||||
("Dies ist ein 1. Test", "Dies ist ein erste Test", "de"), # Issue with gender
|
||||
("Das macht $20 Herr.", "Das macht zwanzig Dollar Herr.", "de"),
|
||||
("Das macht 20€ Herr.", "Das macht zwanzig Euro Herr.", "de"),
|
||||
("Das macht 20,15€ Herr.", "Das macht zwanzig Euro und fünfzehn Cent Herr.", "de"),
|
||||
# Spanish
|
||||
("En 12,5 segundos.", "En doce punto cinco segundos.", "es"),
|
||||
("Había 50 soldados.", "Había cincuenta soldados.", "es"),
|
||||
("Este es un 1er test", "Este es un primero test", "es"),
|
||||
("Eso le costará $20 señor.", "Eso le costará veinte dólares señor.", "es"),
|
||||
("Eso le costará 20€ señor.", "Eso le costará veinte euros señor.", "es"),
|
||||
("Eso le costará 20,15€ señor.", "Eso le costará veinte euros con quince céntimos señor.", "es"),
|
||||
# Italian
|
||||
("In 12,5 secondi.", "In dodici virgola cinque secondi.", "it"),
|
||||
("C'erano 50 soldati.", "C'erano cinquanta soldati.", "it"),
|
||||
("Questo è un 1° test", "Questo è un primo test", "it"),
|
||||
("Ti costerà $20 signore.", "Ti costerà venti dollari signore.", "it"),
|
||||
("Ti costerà 20€ signore.", "Ti costerà venti euro signore.", "it"),
|
||||
("Ti costerà 20,15€ signore.", "Ti costerà venti euro e quindici centesimi signore.", "it"),
|
||||
# Portuguese
|
||||
("Em 12,5 segundos.", "Em doze vírgula cinco segundos.", "pt"),
|
||||
("Havia 50 soldados.", "Havia cinquenta soldados.", "pt"),
|
||||
("Este é um 1º teste", "Este é um primeiro teste", "pt"),
|
||||
("Isso custará $20 senhor.", "Isso custará vinte dólares senhor.", "pt"),
|
||||
("Isso custará 20€ senhor.", "Isso custará vinte euros senhor.", "pt"),
|
||||
(
|
||||
"Isso custará 20,15€ senhor.",
|
||||
"Isso custará vinte euros e quinze cêntimos senhor.",
|
||||
"pt",
|
||||
), # "cêntimos" should be "centavos" num2words issue
|
||||
# Polish
|
||||
("W 12,5 sekundy.", "W dwanaście przecinek pięć sekundy.", "pl"),
|
||||
("Było 50 żołnierzy.", "Było pięćdziesiąt żołnierzy.", "pl"),
|
||||
("To będzie kosztować 20€ panie.", "To będzie kosztować dwadzieścia euro panie.", "pl"),
|
||||
("To będzie kosztować 20,15€ panie.", "To będzie kosztować dwadzieścia euro, piętnaście centów panie.", "pl"),
|
||||
# Arabic
|
||||
("في الـ 12,5 ثانية.", "في الـ اثنا عشر , خمسون ثانية.", "ar"),
|
||||
("كان هناك 50 جنديًا.", "كان هناك خمسون جنديًا.", "ar"),
|
||||
# ("ستكون النتيجة $20 يا سيد.", 'ستكون النتيجة عشرون دولار يا سيد.', 'ar'), # $ and € are mising from num2words
|
||||
# ("ستكون النتيجة 20€ يا سيد.", 'ستكون النتيجة عشرون يورو يا سيد.', 'ar'),
|
||||
# Czech
|
||||
("Za 12,5 vteřiny.", "Za dvanáct celá pět vteřiny.", "cs"),
|
||||
("Bylo tam 50 vojáků.", "Bylo tam padesát vojáků.", "cs"),
|
||||
("To bude stát 20€ pane.", "To bude stát dvacet euro pane.", "cs"),
|
||||
("To bude 20.15€ pane.", "To bude dvacet euro, patnáct centů pane.", "cs"),
|
||||
# Russian
|
||||
("Через 12.5 секунды.", "Через двенадцать запятая пять секунды.", "ru"),
|
||||
("Там было 50 солдат.", "Там было пятьдесят солдат.", "ru"),
|
||||
("Это будет 20.15€ сэр.", "Это будет двадцать евро, пятнадцать центов сэр.", "ru"),
|
||||
("Это будет стоить 20€ господин.", "Это будет стоить двадцать евро господин.", "ru"),
|
||||
# Dutch
|
||||
("In 12,5 seconden.", "In twaalf komma vijf seconden.", "nl"),
|
||||
("Er waren 50 soldaten.", "Er waren vijftig soldaten.", "nl"),
|
||||
("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
|
||||
("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
|
||||
# Chinese (Simplified)
|
||||
("在12.5秒内", "在十二点五秒内", "zh"),
|
||||
("有50名士兵", "有五十名士兵", "zh"),
|
||||
# ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
|
||||
# ("那将是20€先生", '那将是二十欧元先生', 'zh'),
|
||||
# Turkish
|
||||
# ("12,5 saniye içinde.", 'On iki virgül beş saniye içinde.', 'tr'), # decimal doesn't work for TR
|
||||
("50 asker vardı.", "elli asker vardı.", "tr"),
|
||||
("Bu 1. test", "Bu birinci test", "tr"),
|
||||
# ("Bu 100.000,5.", 'Bu yüz bin virgül beş.', 'tr'),
|
||||
# Hungarian
|
||||
("12,5 másodperc alatt.", "tizenkettő egész öt tized másodperc alatt.", "hu"),
|
||||
("50 katona volt.", "ötven katona volt.", "hu"),
|
||||
("Ez az 1. teszt", "Ez az első teszt", "hu"),
|
||||
# Korean
|
||||
("12.5 초 안에.", "십이 점 다섯 초 안에.", "ko"),
|
||||
("50 명의 병사가 있었다.", "오십 명의 병사가 있었다.", "ko"),
|
||||
("이것은 1 번째 테스트입니다", "이것은 첫 번째 테스트입니다", "ko"),
|
||||
]
|
||||
for a, b, lang in test_cases:
|
||||
out = expand_numbers_multilingual(a, lang=lang)
|
||||
assert out == b, f"'{out}' vs '{b}'"
|
||||
|
||||
|
||||
def test_abbreviations_multilingual():
|
||||
test_cases = [
|
||||
# English
|
||||
("Hello Mr. Smith.", "Hello mister Smith.", "en"),
|
||||
("Dr. Jones is here.", "doctor Jones is here.", "en"),
|
||||
# Spanish
|
||||
("Hola Sr. Garcia.", "Hola señor Garcia.", "es"),
|
||||
("La Dra. Martinez es muy buena.", "La doctora Martinez es muy buena.", "es"),
|
||||
# French
|
||||
("Bonjour Mr. Dupond.", "Bonjour monsieur Dupond.", "fr"),
|
||||
("Mme. Moreau est absente aujourd'hui.", "madame Moreau est absente aujourd'hui.", "fr"),
|
||||
# German
|
||||
("Frau Dr. Müller ist sehr klug.", "Frau doktor Müller ist sehr klug.", "de"),
|
||||
# Portuguese
|
||||
("Olá Sr. Silva.", "Olá senhor Silva.", "pt"),
|
||||
("Dra. Costa, você está disponível?", "doutora Costa, você está disponível?", "pt"),
|
||||
# Italian
|
||||
("Buongiorno, Sig. Rossi.", "Buongiorno, signore Rossi.", "it"),
|
||||
# ("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern
|
||||
# Polish
|
||||
("Dzień dobry, P. Kowalski.", "Dzień dobry, pani Kowalski.", "pl"),
|
||||
("M. Nowak, czy mogę zadać pytanie?", "pan Nowak, czy mogę zadać pytanie?", "pl"),
|
||||
# Czech
|
||||
("P. Novák", "pan Novák", "cs"),
|
||||
("Dr. Vojtěch", "doktor Vojtěch", "cs"),
|
||||
# Dutch
|
||||
("Dhr. Jansen", "de heer Jansen", "nl"),
|
||||
("Mevr. de Vries", "mevrouw de Vries", "nl"),
|
||||
# Russian
|
||||
("Здравствуйте Г-н Иванов.", "Здравствуйте господин Иванов.", "ru"),
|
||||
("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", "ru"),
|
||||
# Turkish
|
||||
("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", "tr"),
|
||||
("Dr. Ayşe burada.", "doktor Ayşe burada.", "tr"),
|
||||
# Hungarian
|
||||
("Dr. Szabó itt van.", "doktor Szabó itt van.", "hu"),
|
||||
]
|
||||
|
||||
for a, b, lang in test_cases:
|
||||
out = expand_abbreviations_multilingual(a, lang=lang)
|
||||
assert out == b, f"'{out}' vs '{b}'"
|
||||
|
||||
|
||||
def test_symbols_multilingual():
|
||||
test_cases = [
|
||||
("I have 14% battery", "I have 14 percent battery", "en"),
|
||||
("Te veo @ la fiesta", "Te veo arroba la fiesta", "es"),
|
||||
("J'ai 14° de fièvre", "J'ai 14 degrés de fièvre", "fr"),
|
||||
("Die Rechnung beträgt £ 20", "Die Rechnung beträgt pfund 20", "de"),
|
||||
("O meu email é ana&joao@gmail.com", "O meu email é ana e joao arroba gmail.com", "pt"),
|
||||
("linguaggio di programmazione C#", "linguaggio di programmazione C cancelletto", "it"),
|
||||
("Moja temperatura to 36.6°", "Moja temperatura to 36.6 stopnie", "pl"),
|
||||
("Mám 14% baterie", "Mám 14 procento baterie", "cs"),
|
||||
("Těším se na tebe @ party", "Těším se na tebe na party", "cs"),
|
||||
("У меня 14% заряда", "У меня 14 процентов заряда", "ru"),
|
||||
("Я буду @ дома", "Я буду собака дома", "ru"),
|
||||
("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
|
||||
("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
|
||||
("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
|
||||
("我的电量为 14%", "我的电量为 14 百分之", "zh"),
|
||||
("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
|
||||
("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
|
||||
("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
|
||||
]
|
||||
|
||||
for a, b, lang in test_cases:
|
||||
out = expand_symbols_multilingual(a, lang=lang)
|
||||
assert out == b, f"'{out}' vs '{b}'"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_expand_numbers_multilingual()
|
||||
test_abbreviations_multilingual()
|
||||
test_symbols_multilingual()
|
||||
|
|
|
@ -88,6 +88,7 @@ class XTTSDataset(torch.utils.data.Dataset):
|
|||
self.sample_rate = sample_rate
|
||||
self.max_wav_len = model_args.max_wav_length
|
||||
self.max_text_len = model_args.max_text_length
|
||||
self.use_masking_gt_prompt_approach = model_args.gpt_use_masking_gt_prompt_approach
|
||||
assert self.max_wav_len is not None and self.max_text_len is not None
|
||||
|
||||
self.samples = samples
|
||||
|
@ -109,7 +110,7 @@ class XTTSDataset(torch.utils.data.Dataset):
|
|||
try:
|
||||
tseq, _, wav, _, _, _ = self.load_item(sample)
|
||||
except:
|
||||
pass
|
||||
continue
|
||||
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
|
||||
if (
|
||||
wav is None
|
||||
|
@ -140,10 +141,24 @@ class XTTSDataset(torch.utils.data.Dataset):
|
|||
# Ultra short clips are also useless (and can cause problems within some models).
|
||||
raise ValueError
|
||||
|
||||
# get a slice from GT to condition the model
|
||||
cond, cond_len, cond_idxs = get_prompt_slice(
|
||||
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
|
||||
)
|
||||
if self.use_masking_gt_prompt_approach:
|
||||
# get a slice from GT to condition the model
|
||||
cond, _, cond_idxs = get_prompt_slice(
|
||||
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
|
||||
)
|
||||
# if use masking do not use cond_len
|
||||
cond_len = torch.nan
|
||||
else:
|
||||
ref_sample = (
|
||||
sample["reference_path"]
|
||||
if "reference_path" in sample and sample["reference_path"] is not None
|
||||
else audiopath
|
||||
)
|
||||
cond, cond_len, _ = get_prompt_slice(
|
||||
ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
|
||||
)
|
||||
# if do not use masking use cond_len
|
||||
cond_idxs = torch.nan
|
||||
|
||||
return tseq, audiopath, wav, cond, cond_len, cond_idxs
|
||||
|
||||
|
@ -199,8 +214,10 @@ class XTTSDataset(torch.utils.data.Dataset):
|
|||
"wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
|
||||
"filenames": audiopath,
|
||||
"conditioning": cond.unsqueeze(1),
|
||||
"cond_lens": torch.tensor(cond_len, dtype=torch.long),
|
||||
"cond_idxs": torch.tensor(cond_idxs),
|
||||
"cond_lens": torch.tensor(cond_len, dtype=torch.long)
|
||||
if cond_len is not torch.nan
|
||||
else torch.tensor([cond_len]),
|
||||
"cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_idxs]),
|
||||
}
|
||||
return res
|
||||
|
||||
|
@ -221,6 +238,13 @@ class XTTSDataset(torch.utils.data.Dataset):
|
|||
batch["conditioning"] = torch.stack(batch["conditioning"])
|
||||
batch["cond_lens"] = torch.stack(batch["cond_lens"])
|
||||
batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
|
||||
|
||||
if torch.any(batch["cond_idxs"].isnan()):
|
||||
batch["cond_idxs"] = None
|
||||
|
||||
if torch.any(batch["cond_lens"].isnan()):
|
||||
batch["cond_lens"] = None
|
||||
|
||||
max_text_len = batch["text_lengths"].max()
|
||||
max_wav_len = batch["wav_lengths"].max()
|
||||
|
||||
|
|
|
@ -141,17 +141,30 @@ class GPTTrainer(BaseTTS):
|
|||
print(">> GPT weights restored from:", self.args.gpt_checkpoint)
|
||||
|
||||
# Mel spectrogram extractor for conditioning
|
||||
self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
|
||||
filter_length=4096,
|
||||
hop_length=1024,
|
||||
win_length=4096,
|
||||
normalize=False,
|
||||
sampling_rate=config.audio.sample_rate,
|
||||
mel_fmin=0,
|
||||
mel_fmax=8000,
|
||||
n_mel_channels=80,
|
||||
mel_norm_file=self.args.mel_norm_file,
|
||||
)
|
||||
if self.args.gpt_use_perceiver_resampler:
|
||||
self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
|
||||
filter_length=2048,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
normalize=False,
|
||||
sampling_rate=config.audio.sample_rate,
|
||||
mel_fmin=0,
|
||||
mel_fmax=8000,
|
||||
n_mel_channels=80,
|
||||
mel_norm_file=self.args.mel_norm_file,
|
||||
)
|
||||
else:
|
||||
self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
|
||||
filter_length=4096,
|
||||
hop_length=1024,
|
||||
win_length=4096,
|
||||
normalize=False,
|
||||
sampling_rate=config.audio.sample_rate,
|
||||
mel_fmin=0,
|
||||
mel_fmax=8000,
|
||||
n_mel_channels=80,
|
||||
mel_norm_file=self.args.mel_norm_file,
|
||||
)
|
||||
|
||||
# Load DVAE
|
||||
self.dvae = DiscreteVAE(
|
||||
|
@ -186,7 +199,7 @@ class GPTTrainer(BaseTTS):
|
|||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
|
||||
def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens):
|
||||
"""
|
||||
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
||||
(actuated by `text_first`).
|
||||
|
@ -197,9 +210,16 @@ class GPTTrainer(BaseTTS):
|
|||
wav_lengths: long tensor, (b,)
|
||||
cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
|
||||
cond_idxs: cond start and end indexs, (b, 2)
|
||||
cond_lens: long tensor, (b,)
|
||||
"""
|
||||
losses = self.xtts.gpt(
|
||||
text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
|
||||
text_inputs,
|
||||
text_lengths,
|
||||
audio_codes,
|
||||
wav_lengths,
|
||||
cond_mels=cond_mels,
|
||||
cond_idxs=cond_idxs,
|
||||
cond_lens=cond_lens,
|
||||
)
|
||||
return losses
|
||||
|
||||
|
@ -213,7 +233,12 @@ class GPTTrainer(BaseTTS):
|
|||
print(" | > Synthesizing test sentences.")
|
||||
for idx, s_info in enumerate(self.config.test_sentences):
|
||||
wav = self.xtts.synthesize(
|
||||
s_info["text"], self.config, s_info["speaker_wav"], s_info["language"], gpt_cond_len=3
|
||||
s_info["text"],
|
||||
self.config,
|
||||
s_info["speaker_wav"],
|
||||
s_info["language"],
|
||||
gpt_cond_len=3,
|
||||
decoder="ne_hifigan",
|
||||
)["wav"]
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
|
||||
|
@ -269,7 +294,6 @@ class GPTTrainer(BaseTTS):
|
|||
del batch["padded_text"]
|
||||
del batch["wav"]
|
||||
del batch["conditioning"]
|
||||
del batch["cond_lens"]
|
||||
return batch
|
||||
|
||||
def train_step(self, batch, criterion):
|
||||
|
@ -280,8 +304,11 @@ class GPTTrainer(BaseTTS):
|
|||
audio_codes = batch["audio_codes"]
|
||||
wav_lengths = batch["wav_lengths"]
|
||||
cond_idxs = batch["cond_idxs"]
|
||||
cond_lens = batch["cond_lens"]
|
||||
|
||||
loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
|
||||
loss_text, loss_mel, _ = self.forward(
|
||||
text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens
|
||||
)
|
||||
loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
|
||||
loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
|
||||
loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -252,12 +252,7 @@ class BaseTacotron(BaseTTS):
|
|||
|
||||
def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
|
||||
"""Capacitron Variational Autoencoder"""
|
||||
(
|
||||
VAE_outputs,
|
||||
posterior_distribution,
|
||||
prior_distribution,
|
||||
capacitron_beta,
|
||||
) = self.capacitron_vae_layer(
|
||||
(VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta,) = self.capacitron_vae_layer(
|
||||
reference_mel_info,
|
||||
text_info,
|
||||
speaker_embedding, # pylint: disable=not-callable
|
||||
|
|
|
@ -676,12 +676,7 @@ class Tortoise(BaseTTS):
|
|||
), "Too much text provided. Break the text up into separate segments and re-try inference."
|
||||
|
||||
if voice_samples is not None:
|
||||
(
|
||||
auto_conditioning,
|
||||
diffusion_conditioning,
|
||||
_,
|
||||
_,
|
||||
) = self.get_conditioning_latents(
|
||||
(auto_conditioning, diffusion_conditioning, _, _,) = self.get_conditioning_latents(
|
||||
voice_samples,
|
||||
return_mels=True,
|
||||
latent_averaging_mode=latent_averaging_mode,
|
||||
|
|
|
@ -23,7 +23,19 @@ init_stream_support()
|
|||
|
||||
|
||||
def wav_to_mel_cloning(
|
||||
wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
|
||||
wav,
|
||||
mel_norms_file="../experiments/clips_mel_norms.pth",
|
||||
mel_norms=None,
|
||||
device=torch.device("cpu"),
|
||||
n_fft=4096,
|
||||
hop_length=1024,
|
||||
win_length=4096,
|
||||
power=2,
|
||||
normalized=False,
|
||||
sample_rate=22050,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
n_mels=80,
|
||||
):
|
||||
"""
|
||||
Convert waveform to mel-spectrogram with hard-coded parameters for cloning.
|
||||
|
@ -38,15 +50,15 @@ def wav_to_mel_cloning(
|
|||
torch.Tensor: Mel-spectrogram tensor.
|
||||
"""
|
||||
mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||
n_fft=4096,
|
||||
hop_length=1024,
|
||||
win_length=4096,
|
||||
power=2,
|
||||
normalized=False,
|
||||
sample_rate=22050,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
n_mels=80,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
power=power,
|
||||
normalized=normalized,
|
||||
sample_rate=sample_rate,
|
||||
f_min=f_min,
|
||||
f_max=f_max,
|
||||
n_mels=n_mels,
|
||||
norm="slaney",
|
||||
).to(device)
|
||||
wav = wav.to(device)
|
||||
|
@ -177,19 +189,23 @@ class XttsArgs(Coqpit):
|
|||
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
|
||||
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
|
||||
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
|
||||
use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.
|
||||
use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True.
|
||||
use_ne_hifigan (bool, optional): Whether to use regular hifigan or diffusion + univnet as a decoder. Defaults to False.
|
||||
|
||||
For GPT model:
|
||||
ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
|
||||
ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
|
||||
ar_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70.
|
||||
ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
|
||||
ar_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
|
||||
ar_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
|
||||
ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
|
||||
ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
|
||||
gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
|
||||
gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
|
||||
gpt_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70.
|
||||
gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
|
||||
gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
|
||||
gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
|
||||
gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
|
||||
gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
|
||||
gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False.
|
||||
ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
|
||||
gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
|
||||
gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024.
|
||||
gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True.
|
||||
gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False.
|
||||
|
||||
For DiffTTS model:
|
||||
diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024.
|
||||
|
@ -229,6 +245,9 @@ class XttsArgs(Coqpit):
|
|||
gpt_num_audio_tokens: int = 8194
|
||||
gpt_start_audio_token: int = 8192
|
||||
gpt_stop_audio_token: int = 8193
|
||||
gpt_code_stride_len: int = 1024
|
||||
gpt_use_masking_gt_prompt_approach: bool = True
|
||||
gpt_use_perceiver_resampler: bool = False
|
||||
|
||||
# Diffusion Decoder params
|
||||
diff_model_channels: int = 1024
|
||||
|
@ -247,7 +266,6 @@ class XttsArgs(Coqpit):
|
|||
input_sample_rate: int = 22050
|
||||
output_sample_rate: int = 24000
|
||||
output_hop_length: int = 256
|
||||
ar_mel_length_compression: int = 1024
|
||||
decoder_input_dim: int = 1024
|
||||
d_vector_dim: int = 512
|
||||
cond_d_vector_in_each_upsampling_layer: bool = True
|
||||
|
@ -304,6 +322,8 @@ class Xtts(BaseTTS):
|
|||
num_audio_tokens=self.args.gpt_num_audio_tokens,
|
||||
start_audio_token=self.args.gpt_start_audio_token,
|
||||
stop_audio_token=self.args.gpt_stop_audio_token,
|
||||
use_perceiver_resampler=self.args.gpt_use_perceiver_resampler,
|
||||
code_stride_len=self.args.gpt_code_stride_len,
|
||||
)
|
||||
|
||||
if self.args.use_hifigan:
|
||||
|
@ -311,7 +331,7 @@ class Xtts(BaseTTS):
|
|||
input_sample_rate=self.args.input_sample_rate,
|
||||
output_sample_rate=self.args.output_sample_rate,
|
||||
output_hop_length=self.args.output_hop_length,
|
||||
ar_mel_length_compression=self.args.ar_mel_length_compression,
|
||||
ar_mel_length_compression=self.args.gpt_code_stride_len,
|
||||
decoder_input_dim=self.args.decoder_input_dim,
|
||||
d_vector_dim=self.args.d_vector_dim,
|
||||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
||||
|
@ -322,7 +342,7 @@ class Xtts(BaseTTS):
|
|||
input_sample_rate=self.args.input_sample_rate,
|
||||
output_sample_rate=self.args.output_sample_rate,
|
||||
output_hop_length=self.args.output_hop_length,
|
||||
ar_mel_length_compression=self.args.ar_mel_length_compression,
|
||||
ar_mel_length_compression=self.args.gpt_code_stride_len,
|
||||
decoder_input_dim=self.args.decoder_input_dim,
|
||||
d_vector_dim=self.args.d_vector_dim,
|
||||
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
|
||||
|
@ -354,12 +374,33 @@ class Xtts(BaseTTS):
|
|||
|
||||
Args:
|
||||
audio_path (str): Path to the audio file.
|
||||
sr (int): Sample rate of the audio.
|
||||
length (int): Length of the audio in seconds. Defaults to 3.
|
||||
"""
|
||||
|
||||
audio_22k = torchaudio.functional.resample(audio, sr, 22050)
|
||||
audio_22k = audio_22k[:, : 22050 * length]
|
||||
mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu())
|
||||
if sr != 22050:
|
||||
audio = torchaudio.functional.resample(audio, sr, 22050)
|
||||
audio = audio[:, : 22050 * length]
|
||||
if self.args.gpt_use_perceiver_resampler:
|
||||
n_fft = 2048
|
||||
hop_length = 256
|
||||
win_length = 1024
|
||||
else:
|
||||
n_fft = 4096
|
||||
hop_length = 1024
|
||||
win_length = 4096
|
||||
mel = wav_to_mel_cloning(
|
||||
audio,
|
||||
mel_norms=self.mel_stats.cpu(),
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
power=2,
|
||||
normalized=False,
|
||||
sample_rate=22050,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
n_mels=80,
|
||||
)
|
||||
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
|
||||
return cond_latent.transpose(1, 2)
|
||||
|
||||
|
@ -461,6 +502,9 @@ class Xtts(BaseTTS):
|
|||
"diffusion_temperature": config.diffusion_temperature,
|
||||
"decoder_iterations": config.decoder_iterations,
|
||||
"decoder_sampler": config.decoder_sampler,
|
||||
"gpt_cond_len": config.gpt_cond_len,
|
||||
"max_ref_len": config.max_ref_len,
|
||||
"sound_norm_refs": config.sound_norm_refs,
|
||||
}
|
||||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||
return self.full_inference(text, ref_audio_path, language, **settings)
|
||||
|
@ -477,8 +521,11 @@ class Xtts(BaseTTS):
|
|||
repetition_penalty=2.0,
|
||||
top_k=50,
|
||||
top_p=0.85,
|
||||
gpt_cond_len=6,
|
||||
do_sample=True,
|
||||
# Cloning
|
||||
gpt_cond_len=6,
|
||||
max_ref_len=10,
|
||||
sound_norm_refs=False,
|
||||
# Decoder inference
|
||||
decoder_iterations=100,
|
||||
cond_free=True,
|
||||
|
@ -546,8 +593,12 @@ class Xtts(BaseTTS):
|
|||
Sample rate is 24kHz.
|
||||
"""
|
||||
(gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
|
||||
audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len
|
||||
audio_path=ref_audio_path,
|
||||
gpt_cond_len=gpt_cond_len,
|
||||
max_ref_length=max_ref_len,
|
||||
sound_norm_refs=sound_norm_refs,
|
||||
)
|
||||
|
||||
return self.inference(
|
||||
text,
|
||||
language,
|
||||
|
@ -591,11 +642,16 @@ class Xtts(BaseTTS):
|
|||
diffusion_temperature=1.0,
|
||||
decoder_sampler="ddim",
|
||||
decoder="hifigan",
|
||||
num_beams=1,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
text = text.strip().lower()
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
||||
|
||||
# print(" > Input text: ", text)
|
||||
# print(" > Input text preprocessed: ",self.tokenizer.preprocess_text(text, language))
|
||||
# print(" > Input tokens: ", text_tokens)
|
||||
# print(" > Decoded text: ", self.tokenizer.decode(text_tokens[0].cpu().numpy()))
|
||||
assert (
|
||||
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
|
||||
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
|
||||
|
@ -618,6 +674,7 @@ class Xtts(BaseTTS):
|
|||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
num_return_sequences=self.gpt_batch_size,
|
||||
num_beams=num_beams,
|
||||
length_penalty=length_penalty,
|
||||
repetition_penalty=repetition_penalty,
|
||||
output_attentions=False,
|
||||
|
@ -671,7 +728,12 @@ class Xtts(BaseTTS):
|
|||
)
|
||||
wav = self.vocoder.inference(mel)
|
||||
|
||||
return {"wav": wav.cpu().numpy().squeeze()}
|
||||
return {
|
||||
"wav": wav.cpu().numpy().squeeze(),
|
||||
"gpt_latents": gpt_latents,
|
||||
"speaker_embedding": speaker_embedding,
|
||||
"diffusion_conditioning": diffusion_conditioning,
|
||||
}
|
||||
|
||||
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
|
||||
"""Handle chunk formatting in streaming mode"""
|
||||
|
|
|
@ -392,7 +392,7 @@ class ModelManager(object):
|
|||
self.create_dir_and_download_model(model_name, model_item, output_path)
|
||||
# if the configs are different, redownload it
|
||||
# ToDo: we need a better way to handle it
|
||||
if "xtts_v1" in model_name:
|
||||
if "xtts" in model_name:
|
||||
try:
|
||||
self.check_if_configs_are_equal(model_name, model_item, output_path)
|
||||
except:
|
||||
|
@ -406,7 +406,7 @@ class ModelManager(object):
|
|||
output_model_path = output_path
|
||||
output_config_path = None
|
||||
if (
|
||||
model not in ["tortoise-v2", "bark", "xtts_v1", "xtts_v1.1"] and "fairseq" not in model_name
|
||||
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
|
||||
): # TODO:This is stupid but don't care for now.
|
||||
output_model_path, output_config_path = self._find_files(output_path)
|
||||
# update paths in the config.json
|
||||
|
|
|
@ -7,17 +7,25 @@ This is the same model that powers [Coqui Studio](https://coqui.ai/), and [Coqui
|
|||
a few tricks to make it faster and support streaming inference.
|
||||
|
||||
### Features
|
||||
- Voice cloning with just a 3-second audio clip.
|
||||
- Voice cloning.
|
||||
- Cross-language voice cloning.
|
||||
- Multi-lingual speech generation.
|
||||
- 24khz sampling rate.
|
||||
- Streaming inference with < 200ms latency. (See [Streaming inference](#streaming-inference))
|
||||
- Fine-tuning support. (See [Training](#training))
|
||||
|
||||
### Updates with v2
|
||||
- Improved voice cloning.
|
||||
- Voices can be cloned with a single audio file or multiple audio files, without any effect on the runtime.
|
||||
- 2 new languages: Hungarian and Korean.
|
||||
- Across the board quality improvements.
|
||||
|
||||
### Code
|
||||
Current implementation only supports inference.
|
||||
|
||||
### Languages
|
||||
As of now, XTTS-v1.1 supports 14 languages: English, Spanish, French, German, Italian, Portuguese,
|
||||
Polish, Turkish, Russian, Dutch, Czech, Arabic, Chinese (Simplified) and Japanese.
|
||||
As of now, XTTS-v2 supports 16 languages: English, Spanish, French, German, Italian, Portuguese,
|
||||
Polish, Turkish, Russian, Dutch, Czech, Arabic, Chinese (Simplified), Japanese, Hungarian, Korean
|
||||
|
||||
Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out.
|
||||
|
||||
|
@ -33,7 +41,7 @@ You can also mail us at info@coqui.ai.
|
|||
|
||||
```python
|
||||
from TTS.api import TTS
|
||||
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1.1", gpu=True)
|
||||
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
|
||||
|
||||
# generate speech by cloning a voice using default settings
|
||||
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
|
@ -45,7 +53,7 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
|
|||
#### 🐸TTS Command line
|
||||
|
||||
```console
|
||||
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1.1 \
|
||||
tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \
|
||||
--text "Bugün okula gitmek istemiyorum." \
|
||||
--speaker_wav /path/to/target/speaker.wav \
|
||||
--language_idx tr \
|
||||
|
@ -73,7 +81,7 @@ config.load_json("/path/to/xtts/config.json")
|
|||
model = Xtts.init_from_config(config)
|
||||
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
|
||||
model.cuda()
|
||||
|
||||
|
||||
print("Computing speaker latents...")
|
||||
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
|
||||
|
||||
|
@ -122,7 +130,7 @@ chunks = model.inference_stream(
|
|||
gpt_cond_latent,
|
||||
speaker_embedding
|
||||
)
|
||||
|
||||
|
||||
wav_chuncks = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
if i == 0:
|
||||
|
@ -136,7 +144,7 @@ torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
|
|||
|
||||
### Training
|
||||
|
||||
A recipe for `XTTS_v1.1` GPT encoder training using `LJSpeech` dataset is available at https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
|
||||
A recipe for `XTTS_v2` GPT encoder training using `LJSpeech` dataset is available at https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
|
||||
|
||||
You need to change the fields of the `BaseDatasetConfig` to match your dataset and then update `GPTArgs` and `GPTTrainerConfig` fields as you need. By default, it will use the same parameters that XTTS v1.1 model was trained with. To speed up the model convergence, as default, it will also download the XTTS v1.1 checkpoint and load it.
|
||||
|
||||
|
@ -152,7 +160,7 @@ from TTS.tts.models.xtts import Xtts
|
|||
# Add here the xtts_config path
|
||||
CONFIG_PATH = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT-October-23-2023_10+36AM-653f2e75/config.json"
|
||||
# Add here the vocab file that you have used to train the model
|
||||
TOKENIZER_PATH = "recipes/ljspeech/xtts_v1/run/training/XTTS_v1.1_original_model_files/vocab.json"
|
||||
TOKENIZER_PATH = "recipes/ljspeech/xtts_v1/run/training/XTTS_v2_original_model_files/vocab.json"
|
||||
# Add here the checkpoint that you want to do inference with
|
||||
XTTS_CHECKPOINT = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT/best_model.pth"
|
||||
# Add here the speaker reference
|
||||
|
@ -184,13 +192,14 @@ torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000)
|
|||
```
|
||||
|
||||
|
||||
## Important resources & papers
|
||||
## References and Acknowledgements
|
||||
- VallE: https://arxiv.org/abs/2301.02111
|
||||
- Tortoise Repo: https://github.com/neonbjb/tortoise-tts
|
||||
- Faster implementation: https://github.com/152334H/tortoise-tts-fast
|
||||
- Univnet: https://arxiv.org/abs/2106.07889
|
||||
- Latent Diffusion:https://arxiv.org/abs/2112.10752
|
||||
- DALL-E: https://arxiv.org/abs/2102.12092
|
||||
- Perceiver: https://arxiv.org/abs/2103.03206
|
||||
|
||||
|
||||
## XttsConfig
|
||||
|
|
|
@ -0,0 +1,180 @@
|
|||
import os
|
||||
|
||||
from trainer import Trainer, TrainerArgs
|
||||
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
# Logging parameters
|
||||
RUN_NAME = "GPT_XTTS_v2.0_LJSpeech_FT"
|
||||
PROJECT_NAME = "XTTS_trainer"
|
||||
DASHBOARD_LOGGER = "tensorboard"
|
||||
LOGGER_URI = None
|
||||
|
||||
# Set here the path that the checkpoints will be saved. Default: ./run/training/
|
||||
OUT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "run", "training")
|
||||
|
||||
# Training Parameters
|
||||
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False
|
||||
START_WITH_EVAL = True # if True it will star with evaluation
|
||||
BATCH_SIZE = 3 # set here the batch size
|
||||
GRAD_ACUMM_STEPS = 84 # set here the grad accumulation steps
|
||||
# Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS need to be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE but then set GRAD_ACUMM_STEPS accordingly.
|
||||
|
||||
# Define here the dataset that you want to use for the fine-tuning on.
|
||||
config_dataset = BaseDatasetConfig(
|
||||
formatter="ljspeech",
|
||||
dataset_name="ljspeech",
|
||||
path="/raid/datasets/LJSpeech-1.1_24khz/",
|
||||
meta_file_train="/raid/datasets/LJSpeech-1.1_24khz/metadata.csv",
|
||||
language="en",
|
||||
)
|
||||
|
||||
# Add here the configs of the datasets
|
||||
DATASETS_CONFIG_LIST = [config_dataset]
|
||||
|
||||
# Define the path where XTTS v2.0.1 files will be downloaded
|
||||
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
|
||||
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
|
||||
|
||||
|
||||
# DVAE files
|
||||
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/dvae.pth"
|
||||
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"
|
||||
|
||||
# Set the path to the downloaded files
|
||||
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, DVAE_CHECKPOINT_LINK.split("/")[-1])
|
||||
MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, MEL_NORM_LINK.split("/")[-1])
|
||||
|
||||
# download DVAE files if needed
|
||||
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
|
||||
print(" > Downloading DVAE files!")
|
||||
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
|
||||
|
||||
# ToDo: Update links for XTTS v2.0
|
||||
|
||||
# Download XTTS v2.0 checkpoint if needed
|
||||
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/vocab.json"
|
||||
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v2.0/model.pth"
|
||||
|
||||
# XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
|
||||
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file
|
||||
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1]) # model.pth file
|
||||
|
||||
# download XTTS v2.0 files if needed
|
||||
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
|
||||
print(" > Downloading XTTS v2.0 files!")
|
||||
ModelManager._download_model_files(
|
||||
[TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
|
||||
)
|
||||
|
||||
|
||||
# Training sentences generations
|
||||
SPEAKER_REFERENCE = (
|
||||
"./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
||||
)
|
||||
LANGUAGE = config_dataset.language
|
||||
|
||||
|
||||
def main():
|
||||
# init args and config
|
||||
model_args = GPTArgs(
|
||||
max_conditioning_length=132300, # 6 secs
|
||||
min_conditioning_length=66150, # 3 secs
|
||||
debug_loading_failures=False,
|
||||
max_wav_length=255995, # ~11.6 seconds
|
||||
max_text_length=200,
|
||||
mel_norm_file=MEL_NORM_FILE,
|
||||
dvae_checkpoint=DVAE_CHECKPOINT,
|
||||
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
|
||||
tokenizer_file=TOKENIZER_FILE,
|
||||
gpt_num_audio_tokens=8194,
|
||||
gpt_start_audio_token=8192,
|
||||
gpt_stop_audio_token=8193,
|
||||
use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint
|
||||
gpt_use_masking_gt_prompt_approach=True,
|
||||
gpt_use_perceiver_resampler=True,
|
||||
)
|
||||
# define audio config
|
||||
audio_config = XttsAudioConfig(
|
||||
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
|
||||
)
|
||||
# training parameters config
|
||||
config = GPTTrainerConfig(
|
||||
output_path=OUT_PATH,
|
||||
model_args=model_args,
|
||||
run_name=RUN_NAME,
|
||||
project_name=PROJECT_NAME,
|
||||
run_description="""
|
||||
GPT XTTS training
|
||||
""",
|
||||
dashboard_logger=DASHBOARD_LOGGER,
|
||||
logger_uri=LOGGER_URI,
|
||||
audio=audio_config,
|
||||
batch_size=BATCH_SIZE,
|
||||
batch_group_size=48,
|
||||
eval_batch_size=BATCH_SIZE,
|
||||
num_loader_workers=8,
|
||||
eval_split_max_size=256,
|
||||
print_step=50,
|
||||
plot_step=100,
|
||||
log_model_step=1000,
|
||||
save_step=10000,
|
||||
save_n_checkpoints=1,
|
||||
save_checkpoints=True,
|
||||
# target_loss="loss",
|
||||
print_eval=False,
|
||||
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
|
||||
optimizer="AdamW",
|
||||
optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
|
||||
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
|
||||
lr=5e-06, # learning rate
|
||||
lr_scheduler="MultiStepLR",
|
||||
# it was adjusted accordly for the new step scheme
|
||||
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
|
||||
test_sentences=[
|
||||
{
|
||||
"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"speaker_wav": SPEAKER_REFERENCE,
|
||||
"language": LANGUAGE,
|
||||
},
|
||||
{
|
||||
"text": "This cake is great. It's so delicious and moist.",
|
||||
"speaker_wav": SPEAKER_REFERENCE,
|
||||
"language": LANGUAGE,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# init the model from config
|
||||
model = GPTTrainer.init_from_config(config)
|
||||
|
||||
# load training samples
|
||||
train_samples, eval_samples = load_tts_samples(
|
||||
DATASETS_CONFIG_LIST,
|
||||
eval_split=True,
|
||||
eval_split_max_size=config.eval_split_max_size,
|
||||
eval_split_size=config.eval_split_size,
|
||||
)
|
||||
|
||||
# init the trainer and 🚀
|
||||
trainer = Trainer(
|
||||
TrainerArgs(
|
||||
restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter
|
||||
skip_train_epoch=False,
|
||||
start_with_eval=START_WITH_EVAL,
|
||||
grad_accum_steps=GRAD_ACUMM_STEPS,
|
||||
),
|
||||
config,
|
||||
output_path=OUT_PATH,
|
||||
model=model,
|
||||
train_samples=train_samples,
|
||||
eval_samples=eval_samples,
|
||||
)
|
||||
trainer.fit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -33,6 +33,8 @@ coqpit>=0.0.16
|
|||
# chinese g2p deps
|
||||
jieba
|
||||
pypinyin
|
||||
# korean
|
||||
hangul_romanize
|
||||
# gruut+supported langs
|
||||
gruut[de,es,fr]==2.2.3
|
||||
# deps for korean
|
||||
|
@ -51,3 +53,4 @@ transformers==4.33.*
|
|||
encodec==0.1.*
|
||||
# deps for XTTS
|
||||
unidecode==1.3.*
|
||||
num2words
|
||||
|
|
|
@ -86,6 +86,7 @@ model_args = GPTArgs(
|
|||
gpt_num_audio_tokens=8194,
|
||||
gpt_start_audio_token=8192,
|
||||
gpt_stop_audio_token=8193,
|
||||
use_ne_hifigan=True,
|
||||
)
|
||||
audio_config = XttsAudioConfig(
|
||||
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
from trainer import Trainer, TrainerArgs
|
||||
|
||||
from tests import get_tests_output_path
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||
|
||||
config_dataset = BaseDatasetConfig(
|
||||
formatter="ljspeech",
|
||||
dataset_name="ljspeech",
|
||||
path="tests/data/ljspeech/",
|
||||
meta_file_train="metadata.csv",
|
||||
meta_file_val="metadata.csv",
|
||||
language="en",
|
||||
)
|
||||
|
||||
DATASETS_CONFIG_LIST = [config_dataset]
|
||||
|
||||
# Logging parameters
|
||||
RUN_NAME = "GPT_XTTS_LJSpeech_FT"
|
||||
PROJECT_NAME = "XTTS_trainer"
|
||||
DASHBOARD_LOGGER = "tensorboard"
|
||||
LOGGER_URI = None
|
||||
|
||||
OUT_PATH = os.path.join(get_tests_output_path(), "train_outputs", "xtts_tests")
|
||||
os.makedirs(OUT_PATH, exist_ok=True)
|
||||
|
||||
# Create DVAE checkpoint and mel_norms on test time
|
||||
# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
|
||||
DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint
|
||||
# Mel spectrogram norms, required for dvae mel spectrogram extraction
|
||||
MEL_NORM_FILE = os.path.join(OUT_PATH, "mel_stats.pth")
|
||||
dvae = DiscreteVAE(
|
||||
channels=80,
|
||||
normalization=None,
|
||||
positional_dims=1,
|
||||
num_tokens=8192,
|
||||
codebook_dim=512,
|
||||
hidden_dim=512,
|
||||
num_resnet_blocks=3,
|
||||
kernel_size=3,
|
||||
num_layers=2,
|
||||
use_transposed_convs=False,
|
||||
)
|
||||
torch.save(dvae.state_dict(), DVAE_CHECKPOINT)
|
||||
mel_stats = torch.ones(80)
|
||||
torch.save(mel_stats, MEL_NORM_FILE)
|
||||
|
||||
|
||||
# XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
|
||||
TOKENIZER_FILE = "tests/inputs/xtts_vocab.json" # vocab.json file
|
||||
XTTS_CHECKPOINT = None # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth" # model.pth file
|
||||
|
||||
|
||||
# Training sentences generations
|
||||
SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
||||
LANGUAGE = config_dataset.language
|
||||
|
||||
|
||||
# Training Parameters
|
||||
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False
|
||||
START_WITH_EVAL = False # if True it will star with evaluation
|
||||
BATCH_SIZE = 2 # set here the batch size
|
||||
GRAD_ACUMM_STEPS = 1 # set here the grad accumulation steps
|
||||
# Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS need to be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE but then set GRAD_ACUMM_STEPS accordingly.
|
||||
|
||||
|
||||
# init args and config
|
||||
model_args = GPTArgs(
|
||||
max_conditioning_length=132300, # 6 secs
|
||||
min_conditioning_length=66150, # 3 secs
|
||||
debug_loading_failures=False,
|
||||
max_wav_length=255995, # ~11.6 seconds
|
||||
max_text_length=200,
|
||||
mel_norm_file=MEL_NORM_FILE,
|
||||
dvae_checkpoint=DVAE_CHECKPOINT,
|
||||
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
|
||||
tokenizer_file=TOKENIZER_FILE,
|
||||
gpt_num_audio_tokens=8194,
|
||||
gpt_start_audio_token=8192,
|
||||
gpt_stop_audio_token=8193,
|
||||
gpt_use_masking_gt_prompt_approach=True,
|
||||
gpt_use_perceiver_resampler=True,
|
||||
use_ne_hifigan=True,
|
||||
)
|
||||
audio_config = XttsAudioConfig(
|
||||
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
|
||||
)
|
||||
config = GPTTrainerConfig(
|
||||
epochs=1,
|
||||
output_path=OUT_PATH,
|
||||
model_args=model_args,
|
||||
run_name=RUN_NAME,
|
||||
project_name=PROJECT_NAME,
|
||||
run_description="GPT XTTS training",
|
||||
dashboard_logger=DASHBOARD_LOGGER,
|
||||
logger_uri=LOGGER_URI,
|
||||
audio=audio_config,
|
||||
batch_size=BATCH_SIZE,
|
||||
batch_group_size=48,
|
||||
eval_batch_size=BATCH_SIZE,
|
||||
num_loader_workers=8,
|
||||
eval_split_max_size=256,
|
||||
print_step=50,
|
||||
plot_step=100,
|
||||
log_model_step=1000,
|
||||
save_step=10000,
|
||||
save_n_checkpoints=1,
|
||||
save_checkpoints=True,
|
||||
# target_loss="loss",
|
||||
print_eval=False,
|
||||
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
|
||||
optimizer="AdamW",
|
||||
optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
|
||||
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
|
||||
lr=5e-06, # learning rate
|
||||
lr_scheduler="MultiStepLR",
|
||||
# it was adjusted accordly for the new step scheme
|
||||
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
|
||||
test_sentences=[
|
||||
{
|
||||
"text": "This cake is great. It's so delicious and moist.",
|
||||
"speaker_wav": SPEAKER_REFERENCE,
|
||||
"language": LANGUAGE,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# init the model from config
|
||||
model = GPTTrainer.init_from_config(config)
|
||||
|
||||
# load training samples
|
||||
train_samples, eval_samples = load_tts_samples(
|
||||
DATASETS_CONFIG_LIST,
|
||||
eval_split=True,
|
||||
eval_split_max_size=config.eval_split_max_size,
|
||||
eval_split_size=config.eval_split_size,
|
||||
)
|
||||
|
||||
# init the trainer and 🚀
|
||||
trainer = Trainer(
|
||||
TrainerArgs(
|
||||
restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter
|
||||
skip_train_epoch=False,
|
||||
start_with_eval=True,
|
||||
grad_accum_steps=GRAD_ACUMM_STEPS,
|
||||
),
|
||||
config,
|
||||
output_path=OUT_PATH,
|
||||
model=model,
|
||||
train_samples=train_samples,
|
||||
eval_samples=eval_samples,
|
||||
)
|
||||
trainer.fit()
|
||||
|
||||
# remove output path
|
||||
shutil.rmtree(OUT_PATH)
|
|
@ -16,6 +16,7 @@ MODELS_WITH_SEP_TESTS = [
|
|||
"tts_models/en/multi-dataset/tortoise-v2",
|
||||
"tts_models/multilingual/multi-dataset/xtts_v1",
|
||||
"tts_models/multilingual/multi-dataset/xtts_v1.1",
|
||||
"tts_models/multilingual/multi-dataset/xtts_v2",
|
||||
]
|
||||
|
||||
|
||||
|
@ -126,6 +127,58 @@ def test_xtts_streaming():
|
|||
assert len(wav_chuncks) > 1
|
||||
|
||||
|
||||
def test_xtts_v2():
|
||||
"""XTTS is too big to run on github actions. We need to test it locally"""
|
||||
output_path = os.path.join(get_tests_output_path(), "output.wav")
|
||||
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
|
||||
use_gpu = torch.cuda.is_available()
|
||||
if use_gpu:
|
||||
run_cli(
|
||||
"yes | "
|
||||
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
|
||||
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
|
||||
)
|
||||
else:
|
||||
run_cli(
|
||||
"yes | "
|
||||
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
|
||||
f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
|
||||
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
|
||||
)
|
||||
|
||||
|
||||
def test_xtts_v2_streaming():
|
||||
"""Testing the new inference_stream method"""
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
|
||||
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
|
||||
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v2")
|
||||
config = XttsConfig()
|
||||
config.load_json(os.path.join(model_path, "config.json"))
|
||||
model = Xtts.init_from_config(config)
|
||||
model.load_checkpoint(config, checkpoint_dir=model_path)
|
||||
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
||||
|
||||
print("Computing speaker latents...")
|
||||
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
|
||||
|
||||
print("Inference...")
|
||||
chunks = model.inference_stream(
|
||||
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
|
||||
"en",
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
)
|
||||
wav_chuncks = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
if i == 0:
|
||||
assert chunk.shape[-1] > 5000
|
||||
wav_chuncks.append(chunk)
|
||||
assert len(wav_chuncks) > 1
|
||||
|
||||
|
||||
def test_tortoise():
|
||||
output_path = os.path.join(get_tests_output_path(), "output.wav")
|
||||
use_gpu = torch.cuda.is_available()
|
||||
|
|
Loading…
Reference in New Issue