Merge branch 'coqui-ai:dev' into dev

pull/2735/head
Frederico S. Oliveira 2023-12-11 10:04:07 -03:00 committed by GitHub
commit 163f9a3fdf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 839 additions and 58 deletions

View File

@ -1 +1 @@
0.21.2
0.21.3

View File

@ -0,0 +1,2 @@
faster_whisper==0.9.0
gradio==4.7.1

View File

@ -0,0 +1,160 @@
import os
import gc
import torchaudio
import pandas
from faster_whisper import WhisperModel
from glob import glob
from tqdm import tqdm
import torch
import torchaudio
# torch.set_num_threads(1)
from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners
torch.set_num_threads(16)
import os
audio_types = (".wav", ".mp3", ".flac")
def list_audios(basePath, contains=None):
# return the set of files that are valid
return list_files(basePath, validExts=audio_types, contains=contains)
def list_files(basePath, validExts=None, contains=None):
# loop over the directory structure
for (rootDir, dirNames, filenames) in os.walk(basePath):
# loop over the filenames in the current directory
for filename in filenames:
# if the contains string is not none and the filename does not contain
# the supplied string, then ignore the file
if contains is not None and filename.find(contains) == -1:
continue
# determine the file extension of the current file
ext = filename[filename.rfind("."):].lower()
# check to see if the file is an audio and should be processed
if validExts is None or ext.endswith(validExts):
# construct the path to the audio and yield it
audioPath = os.path.join(rootDir, filename)
yield audioPath
def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
audio_total_size = 0
# make sure that ooutput file exists
os.makedirs(out_path, exist_ok=True)
# Loading Whisper
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading Whisper Model!")
asr_model = WhisperModel("large-v2", device=device, compute_type="float16")
metadata = {"audio_file": [], "text": [], "speaker_name": []}
if gradio_progress is not None:
tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...")
else:
tqdm_object = tqdm(audio_files)
for audio_path in tqdm_object:
wav, sr = torchaudio.load(audio_path)
# stereo to mono if needed
if wav.size(0) != 1:
wav = torch.mean(wav, dim=0, keepdim=True)
wav = wav.squeeze()
audio_total_size += (wav.size(-1) / sr)
segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
segments = list(segments)
i = 0
sentence = ""
sentence_start = None
first_word = True
# added all segments words in a unique list
words_list = []
for _, segment in enumerate(segments):
words = list(segment.words)
words_list.extend(words)
# process each word
for word_idx, word in enumerate(words_list):
if first_word:
sentence_start = word.start
# If it is the first sentence, add buffer or get the begining of the file
if word_idx == 0:
sentence_start = max(sentence_start - buffer, 0) # Add buffer to the sentence start
else:
# get previous sentence end
previous_word_end = words_list[word_idx - 1].end
# add buffer or get the silence midle between the previous sentence and the current one
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)
sentence = word.word
first_word = False
else:
sentence += word.word
if word.word[-1] in ["!", ".", "?"]:
sentence = sentence[1:]
# Expand number and abbreviations plus normalization
sentence = multilingual_cleaners(sentence, target_language)
audio_file_name, _ = os.path.splitext(os.path.basename(audio_path))
audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav"
# Check for the next word's existence
if word_idx + 1 < len(words_list):
next_word_start = words_list[word_idx + 1].start
else:
# If don't have more words it means that it is the last sentence then use the audio len as next word start
next_word_start = (wav.shape[0] - 1) / sr
# Average the current word end and next word start
word_end = min((word.end + next_word_start) / 2, word.end + buffer)
absoulte_path = os.path.join(out_path, audio_file)
os.makedirs(os.path.dirname(absoulte_path), exist_ok=True)
i += 1
first_word = True
audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
# if the audio is too short ignore it (i.e < 0.33 seconds)
if audio.size(-1) >= sr/3:
torchaudio.save(absoulte_path,
audio,
sr
)
else:
continue
metadata["audio_file"].append(audio_file)
metadata["text"].append(sentence)
metadata["speaker_name"].append(speaker_name)
df = pandas.DataFrame(metadata)
df = df.sample(frac=1)
num_val_samples = int(len(df)*eval_percentage)
df_eval = df[:num_val_samples]
df_train = df[num_val_samples:]
df_train = df_train.sort_values('audio_file')
train_metadata_path = os.path.join(out_path, "metadata_train.csv")
df_train.to_csv(train_metadata_path, sep="|", index=False)
eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
df_eval = df_eval.sort_values('audio_file')
df_eval.to_csv(eval_metadata_path, sep="|", index=False)
# deallocate VRAM and RAM
del asr_model, df_train, df_eval, df, metadata
gc.collect()
return train_metadata_path, eval_metadata_path, audio_total_size

View File

@ -0,0 +1,172 @@
import os
import gc
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.utils.manage import ModelManager
def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995):
# Logging parameters
RUN_NAME = "GPT_XTTS_FT"
PROJECT_NAME = "XTTS_trainer"
DASHBOARD_LOGGER = "tensorboard"
LOGGER_URI = None
# Set here the path that the checkpoints will be saved. Default: ./run/training/
OUT_PATH = os.path.join(output_path, "run", "training")
# Training Parameters
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False
START_WITH_EVAL = False # if True it will star with evaluation
BATCH_SIZE = batch_size # set here the batch size
GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps
# Define here the dataset that you want to use for the fine-tuning on.
config_dataset = BaseDatasetConfig(
formatter="coqui",
dataset_name="ft_dataset",
path=os.path.dirname(train_csv),
meta_file_train=train_csv,
meta_file_val=eval_csv,
language=language,
)
# Add here the configs of the datasets
DATASETS_CONFIG_LIST = [config_dataset]
# Define the path where XTTS v2.0.1 files will be downloaded
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
# DVAE files
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
# Set the path to the downloaded files
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK))
MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK))
# download DVAE files if needed
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
print(" > Downloading DVAE files!")
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
# Download XTTS v2.0 checkpoint if needed
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"
XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json"
# XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file
XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK)) # config.json file
# download XTTS v2.0 files if needed
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
print(" > Downloading XTTS v2.0 files!")
ModelManager._download_model_files(
[TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
)
# init args and config
model_args = GPTArgs(
max_conditioning_length=132300, # 6 secs
min_conditioning_length=66150, # 3 secs
debug_loading_failures=False,
max_wav_length=max_audio_length, # ~11.6 seconds
max_text_length=200,
mel_norm_file=MEL_NORM_FILE,
dvae_checkpoint=DVAE_CHECKPOINT,
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
tokenizer_file=TOKENIZER_FILE,
gpt_num_audio_tokens=1026,
gpt_start_audio_token=1024,
gpt_stop_audio_token=1025,
gpt_use_masking_gt_prompt_approach=True,
gpt_use_perceiver_resampler=True,
)
# define audio config
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
# training parameters config
config = GPTTrainerConfig(
epochs=num_epochs,
output_path=OUT_PATH,
model_args=model_args,
run_name=RUN_NAME,
project_name=PROJECT_NAME,
run_description="""
GPT XTTS training
""",
dashboard_logger=DASHBOARD_LOGGER,
logger_uri=LOGGER_URI,
audio=audio_config,
batch_size=BATCH_SIZE,
batch_group_size=48,
eval_batch_size=BATCH_SIZE,
num_loader_workers=8,
eval_split_max_size=256,
print_step=50,
plot_step=100,
log_model_step=100,
save_step=1000,
save_n_checkpoints=1,
save_checkpoints=True,
# target_loss="loss",
print_eval=False,
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
optimizer="AdamW",
optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
lr=5e-06, # learning rate
lr_scheduler="MultiStepLR",
# it was adjusted accordly for the new step scheme
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
test_sentences=[],
)
# init the model from config
model = GPTTrainer.init_from_config(config)
# load training samples
train_samples, eval_samples = load_tts_samples(
DATASETS_CONFIG_LIST,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(
restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter
skip_train_epoch=False,
start_with_eval=START_WITH_EVAL,
grad_accum_steps=GRAD_ACUMM_STEPS,
),
config,
output_path=OUT_PATH,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
)
trainer.fit()
# get the longest text audio file to use as speaker reference
samples_len = [len(item["text"].split(" ")) for item in train_samples]
longest_text_idx = samples_len.index(max(samples_len))
speaker_ref = train_samples[longest_text_idx]["audio_file"]
trainer_out_path = trainer.output_path
# deallocate VRAM and RAM
del model, trainer, train_samples, eval_samples
gc.collect()
return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref

View File

@ -0,0 +1,415 @@
import argparse
import os
import sys
import tempfile
import gradio as gr
import librosa.display
import numpy as np
import os
import torch
import torchaudio
import traceback
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
def clear_gpu_cache():
# clear the GPU cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
XTTS_MODEL = None
def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
global XTTS_MODEL
clear_gpu_cache()
if not xtts_checkpoint or not xtts_config or not xtts_vocab:
return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
config = XttsConfig()
config.load_json(xtts_config)
XTTS_MODEL = Xtts.init_from_config(config)
print("Loading XTTS model! ")
XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
if torch.cuda.is_available():
XTTS_MODEL.cuda()
print("Model Loaded!")
return "Model Loaded!"
def run_tts(lang, tts_text, speaker_audio_file):
if XTTS_MODEL is None or not speaker_audio_file:
return "You need to run the previous step to load the model !!", None, None
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
out = XTTS_MODEL.inference(
text=tts_text,
language=lang,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
length_penalty=XTTS_MODEL.config.length_penalty,
repetition_penalty=XTTS_MODEL.config.repetition_penalty,
top_k=XTTS_MODEL.config.top_k,
top_p=XTTS_MODEL.config.top_p,
)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
out_path = fp.name
torchaudio.save(out_path, out["wav"], 24000)
return "Speech generated !", out_path, speaker_audio_file
# define a logger to redirect
class Logger:
def __init__(self, filename="log.out"):
self.log_file = filename
self.terminal = sys.stdout
self.log = open(self.log_file, "w")
def write(self, message):
self.terminal.write(message)
self.log.write(message)
def flush(self):
self.terminal.flush()
self.log.flush()
def isatty(self):
return False
# redirect stdout and stderr to a file
sys.stdout = Logger()
sys.stderr = sys.stdout
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[
logging.StreamHandler(sys.stdout)
]
)
def read_logs():
sys.stdout.flush()
with open(sys.stdout.log_file, "r") as f:
return f.read()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""XTTS fine-tuning demo\n\n"""
"""
Example runs:
python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
""",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--port",
type=int,
help="Port to run the gradio demo. Default: 5003",
default=5003,
)
parser.add_argument(
"--out_path",
type=str,
help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/",
default="/tmp/xtts_ft/",
)
parser.add_argument(
"--num_epochs",
type=int,
help="Number of epochs to train. Default: 10",
default=10,
)
parser.add_argument(
"--batch_size",
type=int,
help="Batch size. Default: 4",
default=4,
)
parser.add_argument(
"--grad_acumm",
type=int,
help="Grad accumulation steps. Default: 1",
default=1,
)
parser.add_argument(
"--max_audio_length",
type=int,
help="Max permitted audio size in seconds. Default: 11",
default=11,
)
args = parser.parse_args()
with gr.Blocks() as demo:
with gr.Tab("1 - Data processing"):
out_path = gr.Textbox(
label="Output path (where data and checkpoints will be saved):",
value=args.out_path,
)
# upload_file = gr.Audio(
# sources="upload",
# label="Select here the audio files that you want to use for XTTS trainining !",
# type="filepath",
# )
upload_file = gr.File(
file_count="multiple",
label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)",
)
lang = gr.Dropdown(
label="Dataset Language",
value="en",
choices=[
"en",
"es",
"fr",
"de",
"it",
"pt",
"pl",
"tr",
"ru",
"nl",
"cs",
"ar",
"zh",
"hu",
"ko",
"ja"
],
)
progress_data = gr.Label(
label="Progress:"
)
logs = gr.Textbox(
label="Logs:",
interactive=False,
)
demo.load(read_logs, None, logs, every=1)
prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")
def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
clear_gpu_cache()
out_path = os.path.join(out_path, "dataset")
os.makedirs(out_path, exist_ok=True)
if audio_path is None:
return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
else:
try:
train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
except:
traceback.print_exc()
error = traceback.format_exc()
return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
clear_gpu_cache()
# if audio total len is less than 2 minutes raise an error
if audio_total_size < 120:
message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
print(message)
return message, "", ""
print("Dataset Processed!")
return "Dataset Processed!", train_meta, eval_meta
with gr.Tab("2 - Fine-tuning XTTS Encoder"):
train_csv = gr.Textbox(
label="Train CSV:",
)
eval_csv = gr.Textbox(
label="Eval CSV:",
)
num_epochs = gr.Slider(
label="Number of epochs:",
minimum=1,
maximum=100,
step=1,
value=args.num_epochs,
)
batch_size = gr.Slider(
label="Batch size:",
minimum=2,
maximum=512,
step=1,
value=args.batch_size,
)
grad_acumm = gr.Slider(
label="Grad accumulation steps:",
minimum=2,
maximum=128,
step=1,
value=args.grad_acumm,
)
max_audio_length = gr.Slider(
label="Max permitted audio size in seconds:",
minimum=2,
maximum=20,
step=1,
value=args.max_audio_length,
)
progress_train = gr.Label(
label="Progress:"
)
logs_tts_train = gr.Textbox(
label="Logs:",
interactive=False,
)
demo.load(read_logs, None, logs_tts_train, every=1)
train_btn = gr.Button(value="Step 2 - Run the training")
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
clear_gpu_cache()
if not train_csv or not eval_csv:
return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
try:
# convert seconds to waveform frames
max_audio_length = int(max_audio_length * 22050)
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
except:
traceback.print_exc()
error = traceback.format_exc()
return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
# copy original files to avoid parameters changes issues
os.system(f"cp {config_path} {exp_path}")
os.system(f"cp {vocab_file} {exp_path}")
ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
print("Model training done!")
clear_gpu_cache()
return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
with gr.Tab("3 - Inference"):
with gr.Row():
with gr.Column() as col1:
xtts_checkpoint = gr.Textbox(
label="XTTS checkpoint path:",
value="",
)
xtts_config = gr.Textbox(
label="XTTS config path:",
value="",
)
xtts_vocab = gr.Textbox(
label="XTTS vocab path:",
value="",
)
progress_load = gr.Label(
label="Progress:"
)
load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
with gr.Column() as col2:
speaker_reference_audio = gr.Textbox(
label="Speaker reference audio:",
value="",
)
tts_language = gr.Dropdown(
label="Language",
value="en",
choices=[
"en",
"es",
"fr",
"de",
"it",
"pt",
"pl",
"tr",
"ru",
"nl",
"cs",
"ar",
"zh",
"hu",
"ko",
"ja",
]
)
tts_text = gr.Textbox(
label="Input Text.",
value="This model sounds really good and above all, it's reasonably fast.",
)
tts_btn = gr.Button(value="Step 4 - Inference")
with gr.Column() as col3:
progress_gen = gr.Label(
label="Progress:"
)
tts_output_audio = gr.Audio(label="Generated Audio.")
reference_audio = gr.Audio(label="Reference audio used.")
prompt_compute_btn.click(
fn=preprocess_dataset,
inputs=[
upload_file,
lang,
out_path,
],
outputs=[
progress_data,
train_csv,
eval_csv,
],
)
train_btn.click(
fn=train_model,
inputs=[
lang,
train_csv,
eval_csv,
num_epochs,
batch_size,
grad_acumm,
out_path,
max_audio_length,
],
outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
)
load_btn.click(
fn=load_model,
inputs=[
xtts_checkpoint,
xtts_config,
xtts_vocab
],
outputs=[progress_load],
)
tts_btn.click(
fn=run_tts,
inputs=[
tts_language,
tts_text,
speaker_reference_audio,
],
outputs=[progress_gen, tts_output_audio, reference_audio],
)
demo.launch(
share=True,
debug=False,
server_port=args.port,
server_name="0.0.0.0"
)

View File

@ -225,11 +225,11 @@ class GPTTrainer(BaseTTS):
@torch.no_grad()
def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613
test_audios = {}
if self.config.test_sentences:
# init gpt for inference mode
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
self.xtts.gpt.eval()
test_audios = {}
print(" | > Synthesizing test sentences.")
for idx, s_info in enumerate(self.config.test_sentences):
wav = self.xtts.synthesize(

View File

@ -65,7 +65,7 @@ CN_PUNCS_NONSTOP = "
CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP
PUNCS = CN_PUNCS + string.punctuation
PUNCS_TRANSFORM = str.maketrans(PUNCS, " " * len(PUNCS), "") # replace puncs with space
PUNCS_TRANSFORM = str.maketrans(PUNCS, "," * len(PUNCS), "") # replace puncs with English comma
# https://zh.wikipedia.org/wiki/全行和半行

View File

@ -272,6 +272,11 @@ class Xtts(BaseTTS):
style_embs = []
for i in range(0, audio.shape[1], 22050 * chunk_length):
audio_chunk = audio[:, i : i + 22050 * chunk_length]
# if the chunk is too short ignore it
if audio_chunk.size(-1) < 22050 * 0.33:
continue
mel_chunk = wav_to_mel_cloning(
audio_chunk,
mel_norms=self.mel_stats.cpu(),

View File

@ -332,9 +332,9 @@ class ModelManager(object):
def ask_tos(model_full_path):
"""Ask the user to agree to the terms of service"""
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
print(" > You must agree to the terms of service to use this model.")
print(" | > Please see the terms of service at https://coqui.ai/cpml.txt")
print(' | > "I have read, understood and agreed to the Terms and Conditions." - [y/n]')
print(" > You must confirm the following:")
print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"')
print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]')
answer = input(" | | > ")
if answer.lower() == "y":
with open(tos_path, "w", encoding="utf-8") as f:

View File

@ -56,4 +56,4 @@ ModelConfig()
In the example above, ```ModelConfig()``` is the final configuration that the model receives and it has all the fields necessary for the model.
We host pre-defined model configurations under ```TTS/<model_class>/configs/```.Although we recommend a unified config class, you can decompose it as you like as for your custom models as long as all the fields for the trainer, model, and inference APIs are provided.
We host pre-defined model configurations under ```TTS/<model_class>/configs/```. Although we recommend a unified config class, you can decompose it as you like as for your custom models as long as all the fields for the trainer, model, and inference APIs are provided.

View File

@ -21,7 +21,7 @@ them and fine-tune it for your own dataset. This will help you in two main ways:
Fine-tuning comes to the rescue in this case. You can take one of our pre-trained models and fine-tune it on your own
speech dataset and achieve reasonable results with only a couple of hours of data.
However, note that, fine-tuning does not ensure great results. The model performance is still depends on the
However, note that, fine-tuning does not ensure great results. The model performance still depends on the
{ref}`dataset quality <what_makes_a_good_dataset>` and the hyper-parameters you choose for fine-tuning. Therefore,
it still takes a bit of tinkering.
@ -41,7 +41,7 @@ them and fine-tune it for your own dataset. This will help you in two main ways:
tts --list_models
```
The command above lists the the models in a naming format as ```<model_type>/<language>/<dataset>/<model_name>```.
The command above lists the models in a naming format as ```<model_type>/<language>/<dataset>/<model_name>```.
Or you can manually check the `.model.json` file in the project directory.

View File

@ -7,7 +7,7 @@ If you have a single audio file and you need to split it into clips, there are d
It is also important to use a lossless audio file format to prevent compression artifacts. We recommend using `wav` file format.
Let's assume you created the audio clips and their transcription. You can collect all your clips under a folder. Let's call this folder `wavs`.
Let's assume you created the audio clips and their transcription. You can collect all your clips in a folder. Let's call this folder `wavs`.
```
/wavs
@ -17,7 +17,7 @@ Let's assume you created the audio clips and their transcription. You can collec
...
```
You can either create separate transcription files for each clip or create a text file that maps each audio clip to its transcription. In this file, each column must be delimitered by a special character separating the audio file name, the transcription and the normalized transcription. And make sure that the delimiter is not used in the transcription text.
You can either create separate transcription files for each clip or create a text file that maps each audio clip to its transcription. In this file, each column must be delimited by a special character separating the audio file name, the transcription and the normalized transcription. And make sure that the delimiter is not used in the transcription text.
We recommend the following format delimited by `|`. In the following example, `audio1`, `audio2` refer to files `audio1.wav`, `audio2.wav` etc.
@ -55,7 +55,7 @@ For more info about dataset qualities and properties check our [post](https://gi
After you collect and format your dataset, you need to check two things. Whether you need a `formatter` and a `text_cleaner`. The `formatter` loads the text file (created above) as a list and the `text_cleaner` performs a sequence of text normalization operations that converts the raw text into the spoken representation (e.g. converting numbers to text, acronyms, and symbols to the spoken format).
If you use a different dataset format then the LJSpeech or the other public datasets that 🐸TTS supports, then you need to write your own `formatter`.
If you use a different dataset format than the LJSpeech or the other public datasets that 🐸TTS supports, then you need to write your own `formatter`.
If your dataset is in a new language or it needs special normalization steps, then you need a new `text_cleaner`.

View File

@ -2,11 +2,11 @@
- Language frontends are located under `TTS.tts.utils.text`
- Each special language has a separate folder.
- Each folder containst all the utilities for processing the text input.
- Each folder contains all the utilities for processing the text input.
- `TTS.tts.utils.text.phonemizers` contains the main phonemizer for a language. This is the class that uses the utilities
from the previous step and used to convert the text to phonemes or graphemes for the model.
- After you implement your phonemizer, you need to add it to the `TTS/tts/utils/text/phonemizers/__init__.py` to be able to
map the language code in the model config - `config.phoneme_language` - to the phonemizer class and initiate the phonemizer automatically.
- You should also add tests to `tests/text_tests` if you want to make a PR.
We suggest you to check the available implementations as reference. Good luck!
We suggest you to check the available implementations as reference. Good luck!

View File

@ -145,7 +145,7 @@ class MyModel(BaseTTS):
Args:
ap (AudioProcessor): audio processor used at training.
batch (Dict): Model inputs used at the previous training step.
outputs (Dict): Model outputs generated at the previoud training step.
outputs (Dict): Model outputs generated at the previous training step.
Returns:
Tuple[Dict, np.ndarray]: training plots and output waveform.
@ -183,7 +183,7 @@ class MyModel(BaseTTS):
...
def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
"""Setup an return optimizer or optimizers."""
"""Setup a return optimizer or optimizers."""
pass
def get_lr(self) -> Union[float, List[float]]:

View File

@ -2,13 +2,13 @@
## What is Mary-TTS?
[Mary (Modular Architecture for Research in sYynthesis) Text-to-Speech](http://mary.dfki.de/) is an open-source (GNU LGPL license), multilingual Text-to-Speech Synthesis platform written in Java. It was originally developed as a collaborative project of [DFKIs](http://www.dfki.de/web) Language Technology Lab and the [Institute of Phonetics](http://www.coli.uni-saarland.de/groups/WB/Phonetics/) at Saarland University, Germany. It is now maintained by the Multimodal Speech Processing Group in the [Cluster of Excellence MMCI](https://www.mmci.uni-saarland.de/) and DFKI.
[Mary (Modular Architecture for Research in sYnthesis) Text-to-Speech](http://mary.dfki.de/) is an open-source (GNU LGPL license), multilingual Text-to-Speech Synthesis platform written in Java. It was originally developed as a collaborative project of [DFKIs](http://www.dfki.de/web) Language Technology Lab and the [Institute of Phonetics](http://www.coli.uni-saarland.de/groups/WB/Phonetics/) at Saarland University, Germany. It is now maintained by the Multimodal Speech Processing Group in the [Cluster of Excellence MMCI](https://www.mmci.uni-saarland.de/) and DFKI.
MaryTTS has been around for a very! long time. Version 3.0 even dates back to 2006, long before Deep Learning was a broadly known term and the last official release was version 5.2 in 2016.
You can check out this OpenVoice-Tech page to learn more: https://openvoice-tech.net/index.php/MaryTTS
## Why Mary-TTS compatibility is relevant
Due to it's open-source nature, relatively high quality voices and fast synthetization speed Mary-TTS was a popular choice in the past and many tools implemented API support over the years like screen-readers (NVDA + SpeechHub), smart-home HUBs (openHAB, Home Assistant) or voice assistants (Rhasspy, Mycroft, SEPIA). A compatibility layer for Coqui-TTS will ensure that these tools can use Coqui as a drop-in replacement and get even better voices right away.
Due to its open-source nature, relatively high quality voices and fast synthetization speed Mary-TTS was a popular choice in the past and many tools implemented API support over the years like screen-readers (NVDA + SpeechHub), smart-home HUBs (openHAB, Home Assistant) or voice assistants (Rhasspy, Mycroft, SEPIA). A compatibility layer for Coqui-TTS will ensure that these tools can use Coqui as a drop-in replacement and get even better voices right away.
## API and code examples
@ -40,4 +40,4 @@ You can enter the same URLs in your browser and check-out the results there as w
### How it works and limitations
A classic Mary-TTS server would usually show all installed locales and voices via the corresponding endpoints and accept the parameters `LOCALE` and `VOICE` for processing. For Coqui-TTS we usually start the server with one specific locale and model and thus cannot return all available options. Instead we return the active locale and use the model name as "voice". Since we only have one active model and always want to return a WAV-file, we currently ignore all other processing parameters except `INPUT_TEXT`. Since the gender is not defined for models in Coqui-TTS we always return `u` (undefined).
We think that this is an acceptable compromise, since users are often only interested in one specific voice anyways, but the API might get extended in the future to support multiple languages and voices at the same time.
We think that this is an acceptable compromise, since users are often only interested in one specific voice anyways, but the API might get extended in the future to support multiple languages and voices at the same time.

View File

@ -1,6 +1,6 @@
# 🐢 Tortoise
Tortoise is a very expressive TTS system with impressive voice cloning capabilities. It is based on an GPT like autogressive acoustic model that converts input
text to discritized acouistic tokens, a diffusion model that converts these tokens to melspeectrogram frames and a Univnet vocoder to convert the spectrograms to
text to discritized acoustic tokens, a diffusion model that converts these tokens to melspectrogram frames and a Univnet vocoder to convert the spectrograms to
the final audio signal. The important downside is that Tortoise is very slow compared to the parallel TTS models like VITS.
Big thanks to 👑[@manmay-nakhashi](https://github.com/manmay-nakhashi) who helped us implement Tortoise in 🐸TTS.

View File

@ -81,42 +81,6 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
language="en")
```
##### Streaming inference
XTTS supports streaming inference. This is useful for real-time applications.
```python
import os
import time
import torch
import torchaudio
print("Loading model...")
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
model = tts.synthesizer.tts_model
print("Computing speaker latents...")
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
print("Inference...")
t0 = time.time()
stream_generator = model.inference_stream(
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
"en",
gpt_cond_latent,
speaker_embedding
)
wav_chuncks = []
for i, chunk in enumerate(stream_generator):
if i == 0:
print(f"Time to first chunck: {time.time() - t0}")
print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
wav_chuncks.append(chunk)
wav = torch.cat(wav_chuncks, dim=0)
torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```
#### 🐸TTS Command line
##### Single reference
@ -150,14 +114,32 @@ or for all wav files in a directory you can use:
To use the model API, you need to download the model files and pass config and model file paths manually.
##### Calling manually
#### Manual Inference
If you want to be able to run with `use_deepspeed=True` and **enjoy the speedup**, you need to install deepspeed first.
If you want to be able to `load_checkpoint` with `use_deepspeed=True` and **enjoy the speedup**, you need to install deepspeed first.
```console
pip install deepspeed==0.10.3
```
##### inference parameters
- `text`: The text to be synthesized.
- `language`: The language of the text to be synthesized.
- `gpt_cond_latent`: The latent vector you get with get_conditioning_latents. (You can cache for faster inference with same speaker)
- `speaker_embedding`: The speaker embedding you get with get_conditioning_latents. (You can cache for faster inference with same speaker)
- `temperature`: The softmax temperature of the autoregressive model. Defaults to 0.65.
- `length_penalty`: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. Defaults to 1.0.
- `repetition_penalty`: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence of long silences or "uhhhhhhs", etc. Defaults to 2.0.
- `top_k`: Lower values mean the decoder produces more "likely" (aka boring) outputs. Defaults to 50.
- `top_p`: Lower values mean the decoder produces more "likely" (aka boring) outputs. Defaults to 0.8.
- `speed`: The speed rate of the generated audio. Defaults to 1.0. (can produce artifacts if far from 1.0)
- `enable_text_splitting`: Whether to split the text into sentences and generate audio for each sentence. It allows you to have infinite input length but might loose important context between sentences. Defaults to True.
##### Inference
```python
import os
import torch
@ -233,6 +215,50 @@ torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
### Training
#### Easy training
To make `XTTS_v2` GPT encoder training easier for beginner users we did a gradio demo that implements the whole fine-tuning pipeline. The gradio demo enables the user to easily do the following steps:
- Preprocessing of the uploaded audio or audio files in 🐸 TTS coqui formatter
- Train the XTTS GPT encoder with the processed data
- Inference support using the fine-tuned model
The user can run this gradio demo locally or remotely using a Colab Notebook.
##### Run demo on Colab
To make the `XTTS_v2` fine-tuning more accessible for users that do not have good GPUs available we did a Google Colab Notebook.
The Colab Notebook is available [here](https://colab.research.google.com/drive/1GiI4_X724M8q2W-zZ-jXo7cWTV7RfaH-?usp=sharing).
To learn how to use this Colab Notebook please check the [XTTS fine-tuning video]().
If you are not able to acess the video you need to follow the steps:
1. Open the Colab notebook and start the demo by runining the first two cells (ignore pip install errors in the first one).
2. Click on the link "Running on public URL:" on the second cell output.
3. On the first Tab (1 - Data processing) you need to select the audio file or files, wait for upload, and then click on the button "Step 1 - Create dataset" and then wait until the dataset processing is done.
4. Soon as the dataset processing is done you need to go to the second Tab (2 - Fine-tuning XTTS Encoder) and press the button "Step 2 - Run the training" and then wait until the training is finished. Note that it can take up to 40 minutes.
5. Soon the training is done you can go to the third Tab (3 - Inference) and then click on the button "Step 3 - Load Fine-tuned XTTS model" and wait until the fine-tuned model is loaded. Then you can do the inference on the model by clicking on the button "Step 4 - Inference".
##### Run demo locally
To run the demo locally you need to do the following steps:
1. Install 🐸 TTS following the instructions available [here](https://tts.readthedocs.io/en/dev/installation.html#installation).
2. Install the Gradio demo requirements with the command `python3 -m pip install -r TTS/demos/xtts_ft_demo/requirements.txt`
3. Run the Gradio demo using the command `python3 TTS/demos/xtts_ft_demo/xtts_demo.py`
4. Follow the steps presented in the [tutorial video](https://www.youtube.com/watch?v=8tpDiiouGxc&feature=youtu.be) to be able to fine-tune and test the fine-tuned model.
If you are not able to access the video, here is what you need to do:
1. On the first Tab (1 - Data processing) select the audio file or files, wait for upload
2. Click on the button "Step 1 - Create dataset" and then wait until the dataset processing is done.
3. Go to the second Tab (2 - Fine-tuning XTTS Encoder) and press the button "Step 2 - Run the training" and then wait until the training is finished. it will take some time.
4. Go to the third Tab (3 - Inference) and then click on the button "Step 3 - Load Fine-tuned XTTS model" and wait until the fine-tuned model is loaded.
5. Now you can run inference with the model by clicking on the button "Step 4 - Inference".
#### Advanced training
A recipe for `XTTS_v2` GPT encoder training using `LJSpeech` dataset is available at https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
You need to change the fields of the `BaseDatasetConfig` to match your dataset and then update `GPTArgs` and `GPTTrainerConfig` fields as you need. By default, it will use the same parameters that XTTS v1.1 model was trained with. To speed up the model convergence, as default, it will also download the XTTS v1.1 checkpoint and load it.
@ -280,6 +306,7 @@ torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000)
```
## References and Acknowledgements
- VallE: https://arxiv.org/abs/2301.02111
- Tortoise Repo: https://github.com/neonbjb/tortoise-tts