mirror of https://github.com/coqui-ai/TTS.git
Add test sentences during the training
parent
2f868dd5c2
commit
c4ceaabe2c
|
@ -12,7 +12,8 @@ import sys
|
|||
|
||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
||||
from TTS.tts.layers.xtts.gpt import GPT
|
||||
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
|
||||
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig, Xtts
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from coqpit import Coqpit
|
||||
|
@ -25,20 +26,21 @@ from TTS.tts.datasets.dataset import TTSDataset
|
|||
from trainer.torch import DistributedSampler
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
|
||||
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
||||
|
||||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||
|
||||
@dataclass
|
||||
class GPTConfig(TortoiseConfig):
|
||||
class GPTTrainerConfig(XttsConfig):
|
||||
lr: float = 5e-06
|
||||
training_seed: int = 1
|
||||
optimizer_wd_only_on_weights: bool = False
|
||||
weighted_loss_attrs: dict = field(default_factory=lambda: {})
|
||||
weighted_loss_multipliers: dict = field(default_factory=lambda: {})
|
||||
|
||||
test_sentences: List[dict] = field(default_factory=lambda: [])
|
||||
|
||||
@dataclass
|
||||
class XttsAudioConfig(XttsAudioConfig):
|
||||
|
@ -58,7 +60,8 @@ class GPTArgs(XttsArgs):
|
|||
tokenizer_file: str = ""
|
||||
mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
|
||||
dvae_checkpoint: str = ""
|
||||
gpt_checkpoint: str = ""
|
||||
xtts_checkpoint: str = ""
|
||||
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
|
||||
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
|
||||
|
||||
|
||||
|
@ -80,28 +83,18 @@ class GPTTrainer(BaseTTS):
|
|||
"""
|
||||
super().__init__(config, ap=None, tokenizer=None)
|
||||
self.config = config
|
||||
# init XTTS model
|
||||
self.xtts = Xtts(self.config)
|
||||
# create the tokenizer with the target vocabulary
|
||||
self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
|
||||
# init gpt encoder and hifigan decoder
|
||||
self.xtts.init_models()
|
||||
# set mel stats
|
||||
if self.args.mel_norm_file:
|
||||
self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)
|
||||
|
||||
self.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
|
||||
|
||||
self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
|
||||
self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
|
||||
self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
|
||||
|
||||
self.gpt = GPT(
|
||||
layers=self.args.gpt_layers,
|
||||
model_dim=self.args.gpt_n_model_channels,
|
||||
start_text_token=self.args.gpt_start_text_token,
|
||||
stop_text_token=self.args.gpt_stop_text_token,
|
||||
heads=self.args.gpt_n_heads,
|
||||
max_text_tokens=self.args.gpt_max_text_tokens,
|
||||
max_mel_tokens=self.args.gpt_max_audio_tokens,
|
||||
max_prompt_tokens=self.args.gpt_max_prompt_tokens,
|
||||
number_text_tokens=self.args.gpt_number_text_tokens,
|
||||
num_audio_tokens=self.args.gpt_num_audio_tokens,
|
||||
start_audio_token=self.args.gpt_start_audio_token,
|
||||
stop_audio_token=self.args.gpt_stop_audio_token,
|
||||
).cuda()
|
||||
|
||||
if self.args.xtts_checkpoint:
|
||||
self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)
|
||||
|
||||
# load GPT if available
|
||||
if self.args.gpt_checkpoint:
|
||||
|
@ -122,8 +115,8 @@ class GPTTrainer(BaseTTS):
|
|||
del gpt_checkpoint[key]
|
||||
|
||||
# edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible
|
||||
if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.gpt.text_embedding.weight.shape:
|
||||
num_new_tokens = self.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
|
||||
if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape:
|
||||
num_new_tokens = self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
|
||||
print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")
|
||||
|
||||
# add new tokens to a linear layer (text_head)
|
||||
|
@ -137,7 +130,7 @@ class GPTTrainer(BaseTTS):
|
|||
# add new weights to the linear layer (text_head)
|
||||
text_head_weight = gpt_checkpoint["text_head.weight"]
|
||||
start_token_row = text_head_weight[-1, :]
|
||||
new_entry = torch.randn(num_new_tokens, self.gpt.text_head.weight.shape[1])
|
||||
new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1])
|
||||
text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
|
||||
text_head_weight[-1, :] = start_token_row
|
||||
gpt_checkpoint["text_head.weight"] = text_head_weight
|
||||
|
@ -150,10 +143,8 @@ class GPTTrainer(BaseTTS):
|
|||
text_head_bias[-1] = start_token_row
|
||||
gpt_checkpoint["text_head.bias"] = text_head_bias
|
||||
|
||||
self.gpt.load_state_dict(gpt_checkpoint, strict=True)
|
||||
self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
|
||||
print(">> GPT weights restored from:", self.args.gpt_checkpoint)
|
||||
else:
|
||||
print(">> GPT weights randomly initialized! If you want you can specify a checkpoint in config.model_args.gpt_checkpoint")
|
||||
|
||||
# Mel spectrogram extractor for conditioning
|
||||
self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
|
||||
|
@ -195,6 +186,7 @@ class GPTTrainer(BaseTTS):
|
|||
# Mel spectrogram extractor for DVAE
|
||||
self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate)
|
||||
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
@ -211,12 +203,30 @@ class GPTTrainer(BaseTTS):
|
|||
cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
|
||||
cond_idxs: cond start and end indexs, (b, 2)
|
||||
"""
|
||||
losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
|
||||
losses = self.xtts.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
|
||||
return losses
|
||||
|
||||
@torch.no_grad()
|
||||
def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613
|
||||
return {}, {}
|
||||
if self.config.test_sentences:
|
||||
# init gpt for inference mode
|
||||
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
|
||||
self.xtts.gpt.eval()
|
||||
test_audios = {}
|
||||
print(" | > Synthesizing test sentences.")
|
||||
for idx, s_info in enumerate(self.config.test_sentences):
|
||||
wav = self.xtts.synthesize(s_info["text"], self.config, s_info["speaker_wav"], s_info["language"])["wav"]
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
|
||||
# delete inference layers
|
||||
del self.xtts.gpt.gpt_inference
|
||||
del self.xtts.gpt.gpt.wte
|
||||
return {"audios": test_audios}
|
||||
|
||||
def test_log(
|
||||
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate)
|
||||
|
||||
def format_batch(self, batch: Dict) -> Dict:
|
||||
return batch
|
||||
|
@ -323,7 +333,7 @@ class GPTTrainer(BaseTTS):
|
|||
loader = None
|
||||
else:
|
||||
# init dataloader
|
||||
dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)
|
||||
dataset = XTTSDataset(self.config, samples, self.xtts.tokenizer, config.audio.sample_rate, is_eval)
|
||||
|
||||
# wait all the DDP process to be ready
|
||||
if num_gpus > 1:
|
||||
|
@ -362,7 +372,7 @@ class GPTTrainer(BaseTTS):
|
|||
# ToDo: deal with multi GPU training
|
||||
if self.config.optimizer_wd_only_on_weights:
|
||||
# parameters to only GPT model
|
||||
net = self.gpt
|
||||
net = self.xtts.gpt
|
||||
|
||||
# normalizations
|
||||
norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
|
||||
|
@ -410,7 +420,7 @@ class GPTTrainer(BaseTTS):
|
|||
self.config.optimizer_params,
|
||||
self.config.lr,
|
||||
# optimize only for the GPT model
|
||||
parameters=self.gpt.parameters(),
|
||||
parameters=self.xtts.gpt.parameters(),
|
||||
)
|
||||
|
||||
def get_scheduler(self, optimizer) -> List:
|
||||
|
@ -432,21 +442,21 @@ class GPTTrainer(BaseTTS):
|
|||
target_options={"anon": True},
|
||||
): # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
|
||||
"""Load the model checkpoint and setup for training or inference"""
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))["model"]
|
||||
# load the model weights
|
||||
self.gpt.load_state_dict(state, strict=strict)
|
||||
self.xtts.load_state_dict(state, strict=strict)
|
||||
|
||||
if eval:
|
||||
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
|
||||
self.eval()
|
||||
self.set_inference()
|
||||
assert not self.training
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "GPTConfig", samples: Union[List[List], List[Dict]] = None):
|
||||
def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None):
|
||||
"""Initiate model from config
|
||||
|
||||
Args:
|
||||
config (GPTConfig): Model config.
|
||||
config (GPTTrainerConfig): Model config.
|
||||
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
|
|
@ -387,7 +387,7 @@ class Xtts(BaseTTS):
|
|||
audio = load_audio(audio_path)
|
||||
audio = audio[:, : 22050 * length]
|
||||
mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
|
||||
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
|
||||
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
|
||||
return cond_latent.transpose(1, 2)
|
||||
|
||||
@torch.inference_mode()
|
||||
|
|
|
@ -3,7 +3,7 @@ from trainer import Trainer, TrainerArgs
|
|||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTConfig
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTTrainerConfig
|
||||
|
||||
|
||||
config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig(
|
||||
|
@ -265,21 +265,21 @@ def main():
|
|||
debug_loading_failures=False,
|
||||
max_wav_length=255995, # ~11.6 seconds
|
||||
max_text_length=200,
|
||||
tokenizer_file="/raid/datasets/xtts_models/vocab.json",
|
||||
mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
|
||||
dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
|
||||
gpt_checkpoint="/raid/datasets/xtts_models/gpt.pth",
|
||||
tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune
|
||||
xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth", # checkpoint path of the model that you want to fine-tune
|
||||
gpt_num_audio_tokens=8194,
|
||||
gpt_start_audio_token=8192,
|
||||
gpt_stop_audio_token=8193,
|
||||
)
|
||||
audio_config = XttsAudioConfig(
|
||||
sample_rate=22050, # autoregressive SR
|
||||
sample_rate=22050, # GPT SR
|
||||
dvae_sample_rate=22050,
|
||||
diffusion_sample_rate=24000,
|
||||
output_sample_rate=24000
|
||||
)
|
||||
config = GPTConfig(
|
||||
config = GPTTrainerConfig(
|
||||
output_path=OUT_PATH,
|
||||
model_args=model_args,
|
||||
run_name=RUN_NAME,
|
||||
|
@ -313,6 +313,10 @@ def main():
|
|||
lr_scheduler="MultiStepLR",
|
||||
# it was adjusted accordly for the new step scheme
|
||||
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
|
||||
test_sentences=[
|
||||
{"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
|
||||
{"text": "This cake is great. It's so delicious and moist.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
|
||||
]
|
||||
)
|
||||
|
||||
# init the model from config
|
||||
|
@ -341,7 +345,7 @@ def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
RUN_NAME = "GPT_XTTS"
|
||||
PROJECT_NAME = "XTTS"
|
||||
PROJECT_NAME = "XTTS_trainer"
|
||||
OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_style_emb/"
|
||||
DASHBOARD_LOGGER = "clearml"
|
||||
LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_style_emb/"
|
||||
|
@ -352,12 +356,11 @@ if __name__ == "__main__":
|
|||
GRAD_ACUMM_STEPS = 28
|
||||
|
||||
# debug
|
||||
DASHBOARD_LOGGER = "tensorboard"
|
||||
LOGGER_URI = None
|
||||
RESTORE_PATH = None
|
||||
BATCH_SIZE = 10
|
||||
# DASHBOARD_LOGGER = "tensorboard"
|
||||
# LOGGER_URI = None
|
||||
# RESTORE_PATH = None
|
||||
BATCH_SIZE = 2
|
||||
GRAD_ACUMM_STEPS = 1
|
||||
NUM_LOADERS = 1
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue