Add lang code in XTTS doc (#3158)

* Add lang code in XTTS doc

* Remove ununsed config and args

* update docs

* woops
pull/3173/head
Julian Weber 2023-11-08 13:47:33 +01:00 committed by GitHub
parent 78a596618a
commit 03ad90135b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 66 deletions

View File

@ -37,29 +37,11 @@ class XttsConfig(BaseTTSConfig):
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
Defaults to `0.8`.
cond_free_k (float):
Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k. Defaults to `2.0`.
diffusion_temperature (float):
Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
are the "mean" prediction of the diffusion network and will sound bland and smeared.
Defaults to `1.0`.
num_gpt_outputs (int):
Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
As XTTS is a probabilistic model, more samples means a higher probability of creating something "great".
Defaults to `16`.
decoder_iterations (int):
Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
however. Defaults to `30`.
decoder_sampler (str):
Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`.
gpt_cond_len (int):
Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
@ -110,11 +92,7 @@ class XttsConfig(BaseTTSConfig):
repetition_penalty: float = 2.0
top_k: int = 50
top_p: float = 0.85
cond_free_k: float = 2.0
diffusion_temperature: float = 1.0
num_gpt_outputs: int = 1
decoder_iterations: int = 30
decoder_sampler: str = "ddim"
# cloning
gpt_cond_len: int = 3

View File

@ -152,19 +152,6 @@ class XttsArgs(Coqpit):
gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024.
gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True.
gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False.
For DiffTTS model:
diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024.
diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10.
diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100.
diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200.
diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024.
diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193.
diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0.
diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. Defaults to False.
diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16.
diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0.
diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0.
"""
gpt_batch_size: int = 1
@ -193,19 +180,6 @@ class XttsArgs(Coqpit):
gpt_use_masking_gt_prompt_approach: bool = True
gpt_use_perceiver_resampler: bool = False
# Diffusion Decoder params
diff_model_channels: int = 1024
diff_num_layers: int = 10
diff_in_channels: int = 100
diff_out_channels: int = 200
diff_in_latent_channels: int = 1024
diff_in_tokens: int = 8193
diff_dropout: int = 0
diff_use_fp16: bool = False
diff_num_heads: int = 16
diff_layer_drop: int = 0
diff_unconditioned_percentage: int = 0
# HifiGAN Decoder params
input_sample_rate: int = 22050
output_sample_rate: int = 24000
@ -426,10 +400,6 @@ class Xtts(BaseTTS):
"repetition_penalty": config.repetition_penalty,
"top_k": config.top_k,
"top_p": config.top_p,
"cond_free_k": config.cond_free_k,
"diffusion_temperature": config.diffusion_temperature,
"decoder_iterations": config.decoder_iterations,
"decoder_sampler": config.decoder_sampler,
"gpt_cond_len": config.gpt_cond_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
@ -454,13 +424,6 @@ class Xtts(BaseTTS):
gpt_cond_len=6,
max_ref_len=10,
sound_norm_refs=False,
# Decoder inference
decoder_iterations=100,
cond_free=True,
cond_free_k=2,
diffusion_temperature=1.0,
decoder_sampler="ddim",
decoder="hifigan",
**hf_generate_kwargs,
):
"""

View File

@ -24,8 +24,7 @@ a few tricks to make it faster and support streaming inference.
Current implementation only supports inference.
### Languages
As of now, XTTS-v2 supports 16 languages: English, Spanish, French, German, Italian, Portuguese,
Polish, Turkish, Russian, Dutch, Czech, Arabic, Chinese (Simplified), Japanese, Hungarian, Korean
As of now, XTTS-v2 supports 16 languages: English (en), Spanish (es), French (fr), German (de), Italian (it), Portuguese (pt), Polish (pl), Turkish (tr), Russian (ru), Dutch (nl), Czech (cs), Arabic (ar), Chinese (zh-cn), Japanese (ja), Hungarian (hu) and Korean (ko).
Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out.
@ -116,7 +115,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=Tru
model.cuda()
print("Computing speaker latents...")
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
print("Inference...")
out = model.inference(
@ -124,7 +123,6 @@ out = model.inference(
"en",
gpt_cond_latent,
speaker_embedding,
diffusion_conditioning,
temperature=0.7, # Add custom parameters here
)
torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
@ -153,7 +151,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=Tru
model.cuda()
print("Computing speaker latents...")
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
print("Inference...")
t0 = time.time()
@ -210,7 +208,7 @@ model.load_checkpoint(config, checkpoint_path=XTTS_CHECKPOINT, vocab_path=TOKENI
model.cuda()
print("Computing speaker latents...")
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=[SPEAKER_REFERENCE])
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[SPEAKER_REFERENCE])
print("Inference...")
out = model.inference(
@ -218,7 +216,6 @@ out = model.inference(
"en",
gpt_cond_latent,
speaker_embedding,
diffusion_conditioning,
temperature=0.7, # Add custom parameters here
)
torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000)