Add test sentences during the training

pull/3086/head
Edresson Casanova 2023-10-16 15:32:00 -03:00
parent 2f868dd5c2
commit c4ceaabe2c
3 changed files with 67 additions and 54 deletions

View File

@ -12,7 +12,8 @@ import sys
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig, Xtts
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.base_tts import BaseTTS
from coqpit import Coqpit
@ -25,20 +26,21 @@ from TTS.tts.datasets.dataset import TTSDataset
from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.utils.io import load_fsspec
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
@dataclass
class GPTConfig(TortoiseConfig):
class GPTTrainerConfig(XttsConfig):
lr: float = 5e-06
training_seed: int = 1
optimizer_wd_only_on_weights: bool = False
weighted_loss_attrs: dict = field(default_factory=lambda: {})
weighted_loss_multipliers: dict = field(default_factory=lambda: {})
test_sentences: List[dict] = field(default_factory=lambda: [])
@dataclass
class XttsAudioConfig(XttsAudioConfig):
@ -58,7 +60,8 @@ class GPTArgs(XttsArgs):
tokenizer_file: str = ""
mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
dvae_checkpoint: str = ""
gpt_checkpoint: str = ""
xtts_checkpoint: str = ""
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
@ -80,28 +83,18 @@ class GPTTrainer(BaseTTS):
"""
super().__init__(config, ap=None, tokenizer=None)
self.config = config
# init XTTS model
self.xtts = Xtts(self.config)
# create the tokenizer with the target vocabulary
self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
# init gpt encoder and hifigan decoder
self.xtts.init_models()
# set mel stats
if self.args.mel_norm_file:
self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)
self.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
self.gpt = GPT(
layers=self.args.gpt_layers,
model_dim=self.args.gpt_n_model_channels,
start_text_token=self.args.gpt_start_text_token,
stop_text_token=self.args.gpt_stop_text_token,
heads=self.args.gpt_n_heads,
max_text_tokens=self.args.gpt_max_text_tokens,
max_mel_tokens=self.args.gpt_max_audio_tokens,
max_prompt_tokens=self.args.gpt_max_prompt_tokens,
number_text_tokens=self.args.gpt_number_text_tokens,
num_audio_tokens=self.args.gpt_num_audio_tokens,
start_audio_token=self.args.gpt_start_audio_token,
stop_audio_token=self.args.gpt_stop_audio_token,
).cuda()
if self.args.xtts_checkpoint:
self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)
# load GPT if available
if self.args.gpt_checkpoint:
@ -122,8 +115,8 @@ class GPTTrainer(BaseTTS):
del gpt_checkpoint[key]
# edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible
if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.gpt.text_embedding.weight.shape:
num_new_tokens = self.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape:
num_new_tokens = self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")
# add new tokens to a linear layer (text_head)
@ -137,7 +130,7 @@ class GPTTrainer(BaseTTS):
# add new weights to the linear layer (text_head)
text_head_weight = gpt_checkpoint["text_head.weight"]
start_token_row = text_head_weight[-1, :]
new_entry = torch.randn(num_new_tokens, self.gpt.text_head.weight.shape[1])
new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1])
text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
text_head_weight[-1, :] = start_token_row
gpt_checkpoint["text_head.weight"] = text_head_weight
@ -150,10 +143,8 @@ class GPTTrainer(BaseTTS):
text_head_bias[-1] = start_token_row
gpt_checkpoint["text_head.bias"] = text_head_bias
self.gpt.load_state_dict(gpt_checkpoint, strict=True)
self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
print(">> GPT weights restored from:", self.args.gpt_checkpoint)
else:
print(">> GPT weights randomly initialized! If you want you can specify a checkpoint in config.model_args.gpt_checkpoint")
# Mel spectrogram extractor for conditioning
self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
@ -195,6 +186,7 @@ class GPTTrainer(BaseTTS):
# Mel spectrogram extractor for DVAE
self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate)
@property
def device(self):
return next(self.parameters()).device
@ -211,12 +203,30 @@ class GPTTrainer(BaseTTS):
cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
cond_idxs: cond start and end indexs, (b, 2)
"""
losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
losses = self.xtts.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
return losses
@torch.no_grad()
def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613
return {}, {}
if self.config.test_sentences:
# init gpt for inference mode
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
self.xtts.gpt.eval()
test_audios = {}
print(" | > Synthesizing test sentences.")
for idx, s_info in enumerate(self.config.test_sentences):
wav = self.xtts.synthesize(s_info["text"], self.config, s_info["speaker_wav"], s_info["language"])["wav"]
test_audios["{}-audio".format(idx)] = wav
# delete inference layers
del self.xtts.gpt.gpt_inference
del self.xtts.gpt.gpt.wte
return {"audios": test_audios}
def test_log(
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate)
def format_batch(self, batch: Dict) -> Dict:
return batch
@ -323,7 +333,7 @@ class GPTTrainer(BaseTTS):
loader = None
else:
# init dataloader
dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)
dataset = XTTSDataset(self.config, samples, self.xtts.tokenizer, config.audio.sample_rate, is_eval)
# wait all the DDP process to be ready
if num_gpus > 1:
@ -362,7 +372,7 @@ class GPTTrainer(BaseTTS):
# ToDo: deal with multi GPU training
if self.config.optimizer_wd_only_on_weights:
# parameters to only GPT model
net = self.gpt
net = self.xtts.gpt
# normalizations
norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
@ -410,7 +420,7 @@ class GPTTrainer(BaseTTS):
self.config.optimizer_params,
self.config.lr,
# optimize only for the GPT model
parameters=self.gpt.parameters(),
parameters=self.xtts.gpt.parameters(),
)
def get_scheduler(self, optimizer) -> List:
@ -432,21 +442,21 @@ class GPTTrainer(BaseTTS):
target_options={"anon": True},
): # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))["model"]
# load the model weights
self.gpt.load_state_dict(state, strict=strict)
self.xtts.load_state_dict(state, strict=strict)
if eval:
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
self.eval()
self.set_inference()
assert not self.training
@staticmethod
def init_from_config(config: "GPTConfig", samples: Union[List[List], List[Dict]] = None):
def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (GPTConfig): Model config.
config (GPTTrainerConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""

View File

@ -387,7 +387,7 @@ class Xtts(BaseTTS):
audio = load_audio(audio_path)
audio = audio[:, : 22050 * length]
mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
return cond_latent.transpose(1, 2)
@torch.inference_mode()

View File

@ -3,7 +3,7 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTTrainerConfig
config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig(
@ -265,21 +265,21 @@ def main():
debug_loading_failures=False,
max_wav_length=255995, # ~11.6 seconds
max_text_length=200,
tokenizer_file="/raid/datasets/xtts_models/vocab.json",
mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
gpt_checkpoint="/raid/datasets/xtts_models/gpt.pth",
tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune
xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth", # checkpoint path of the model that you want to fine-tune
gpt_num_audio_tokens=8194,
gpt_start_audio_token=8192,
gpt_stop_audio_token=8193,
)
audio_config = XttsAudioConfig(
sample_rate=22050, # autoregressive SR
sample_rate=22050, # GPT SR
dvae_sample_rate=22050,
diffusion_sample_rate=24000,
output_sample_rate=24000
)
config = GPTConfig(
config = GPTTrainerConfig(
output_path=OUT_PATH,
model_args=model_args,
run_name=RUN_NAME,
@ -313,6 +313,10 @@ def main():
lr_scheduler="MultiStepLR",
# it was adjusted accordly for the new step scheme
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
test_sentences=[
{"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
{"text": "This cake is great. It's so delicious and moist.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
]
)
# init the model from config
@ -341,7 +345,7 @@ def main():
if __name__ == "__main__":
RUN_NAME = "GPT_XTTS"
PROJECT_NAME = "XTTS"
PROJECT_NAME = "XTTS_trainer"
OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_style_emb/"
DASHBOARD_LOGGER = "clearml"
LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_style_emb/"
@ -352,12 +356,11 @@ if __name__ == "__main__":
GRAD_ACUMM_STEPS = 28
# debug
DASHBOARD_LOGGER = "tensorboard"
LOGGER_URI = None
RESTORE_PATH = None
BATCH_SIZE = 10
# DASHBOARD_LOGGER = "tensorboard"
# LOGGER_URI = None
# RESTORE_PATH = None
BATCH_SIZE = 2
GRAD_ACUMM_STEPS = 1
NUM_LOADERS = 1