mirror of https://github.com/coqui-ai/TTS.git
Implement coarse model finetuning
parent 0b1312c0c0
commit 3bd9f217dc
@@ -63,7 +63,9 @@ class BarkDataset(Dataset):
        return {
            "raw_text": raw_text,
            "text_len": len(raw_text),
            "wav": wav,
            "wav_len": wav.shape[1],
            "wav_file": wav_filename,
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
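For context, a minimal sketch of a helper that would build one dataset item in the layout above. It is not taken from this commit; the sample dict with "text", "audio_file", "speaker_name" and "language" keys is an assumption borrowed from the common TTS sample format, and the helper name is hypothetical.

import torchaudio

def make_bark_item(sample):
    """Build one BarkDataset-style item from an assumed TTS sample dict."""
    wav, _sr = torchaudio.load(sample["audio_file"])  # wav: [channels, num_samples]
    return {
        "raw_text": sample["text"],
        "text_len": len(sample["text"]),
        "wav": wav,
        "wav_len": wav.shape[1],
        "wav_file": sample["audio_file"],
        "speaker_name": sample["speaker_name"],
        "language_name": sample["language"],
    }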
@@ -143,9 +145,9 @@ class Bark(BaseTTS):
        super().__init__(config=config, ap=None, tokenizer=None, speaker_manager=None, language_manager=None)
        self.config.num_chars = len(tokenizer)
        self.tokenizer = tokenizer
-        self.semantic_model = GPT(config.semantic_config)
-        self.coarse_model = GPT(config.coarse_config)
-        self.fine_model = FineGPT(config.fine_config)
+        self.semantic_model = GPT(config.semantic_gpt_config)
+        self.coarse_model = GPT(config.coarse_gpt_config)
+        self.fine_model = FineGPT(config.fine_gpt_config)
        self.encodec = EncodecModel.encodec_model_24khz()
        self.encodec.set_target_bandwidth(6.0)
        self.semantic_tokenizer = BarkHubertAudioTokenizer(self.config, lazy_load=self.config.training_mode)
@@ -154,11 +156,19 @@ class Bark(BaseTTS):
    def device(self):
        return next(self.parameters()).device

+    @property
+    def pad_token(self):
+        if self.config.training_mode == "semantic":
+            return self.config.SEMANTIC_PAD_TOKEN
+        elif self.config.training_mode in ["coarse", "fine"]:
+            return self.config.COARSE_SEMANTIC_PAD_TOKEN
+        else:
+            raise ValueError("Invalid training mode: {}".format(self.config.training_mode))
+
    def load_bark_models(self):
        self.semantic_model, self.config = load_model(
            ckpt_path=self.config.LOCAL_MODEL_PATHS["text"], device=self.device, config=self.config, model_type="text"
        )

        self.coarse_model, self.config = load_model(
            ckpt_path=self.config.LOCAL_MODEL_PATHS["coarse"],
            device=self.device,
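The new pad_token property centralizes the mode-dependent padding id that used to be re-selected inline. The same value is reused twice below: as the fill value when collating variable-length sequences and as the ignore_index of the cross-entropy criterion, so padded positions never contribute to the loss. A small self-contained illustration of that masking behaviour (the pad id 10000 is made up for the example, not a value from this commit):

import torch

pad_id = 10000                                    # illustrative value only
criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_id)

logits = torch.randn(4, pad_id + 1)               # 4 positions, vocab large enough to contain pad_id
targets = torch.tensor([3, 7, pad_id, pad_id])    # the last two positions are padding
loss = criterion(logits, targets)                 # averaged over the two non-pad positions only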
@@ -169,17 +179,19 @@ class Bark(BaseTTS):
            ckpt_path=self.config.LOCAL_MODEL_PATHS["fine"], device=self.device, config=self.config, model_type="fine"
        )

-    def generate_coarse_fine_tokens(self, audio):
+    def generate_coarse_fine_tokens(
+        self,
+        audio,
+    ):
        if isinstance(audio, str):
            audio, sr = torchaudio.load(audio)
-            audio = convert_audio_and_make_label(audio, sr, self.config.sample_rate, self.encodec.channels)
+            audio = convert_audio(audio, sr, self.config.sample_rate, self.encodec.channels)
            audio = audio.unsqueeze(0).to(self.device)

        # Coarse and fine tokens
        with torch.no_grad():
            encoded_frames = self.encodec.encode(audio)
        codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()  # [n_q, T]
        codes = codes.cpu().numpy()
        return codes, codes[:2, :]  # fine, coarse

    def generate_semantic_tokens(self, audio):
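For reference, a standalone sketch of what generate_coarse_fine_tokens computes with the encodec package, independent of this commit's code: at the 6 kbps bandwidth set in __init__, the 24 kHz EnCodec model produces 8 codebooks at 75 frames per second, of which the first two serve as Bark's "coarse" tokens and all eight as the "fine" targets. Shapes and the (codes, scale) frame layout are stated from the encodec API as I understand it and should be treated as assumptions.

import torch
from encodec import EncodecModel

codec = EncodecModel.encodec_model_24khz()
codec.set_target_bandwidth(6.0)                        # 6 kbps -> 8 codebooks of 1024 entries

audio = torch.zeros(1, codec.channels, 24000)          # 1 s of silence, [batch, channels, samples]
with torch.no_grad():
    frames = codec.encode(audio)                       # list of (codes, scale) frames
codes = torch.cat([c for c, _ in frames], dim=-1).squeeze(0)  # [n_q, T]

fine_tokens = codes                                    # all 8 codebooks
coarse_tokens = codes[:2, :]                           # first 2 codebooks only
print(fine_tokens.shape, coarse_tokens.shape)          # expected: [8, 75] and [2, 75]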
@@ -354,14 +366,6 @@ class Bark(BaseTTS):
        Returns:
            formatted batch
        """
-        PAD_TOKEN = None
-        if self.config.training_mode == "semantic":
-            PAD_TOKEN = self.config.SEMANTIC_PAD_TOKEN
-        elif self.config.training_mode in ["coarse", "fine"]:
-            PAD_TOKEN = self.config.COARSE_SEMANTIC_PAD_TOKEN
-        else:
-            raise ValueError("Invalid training mode: {}".format(self.config.training_mode))

        tokenss = []
        max_len = 0
        for i, text in enumerate(batch["raw_text"]):
@@ -370,11 +374,12 @@ class Bark(BaseTTS):
            tokenss.append(tokens)
            max_len = max(max_len, len(tokens))

-        # pad and collate into batch
-        for i, tokens in enumerate(tokenss):
-            tokenss[i] = torch.nn.functional.pad(tokens, (0, max_len - len(tokens)), value=PAD_TOKEN)
-        tokens = torch.stack(tokenss, dim=0)
-        batch["input_ids"] = tokens[:, : self.config.max_text_tokens_len]
+        if self.config.training_mode == "semantic":
+            # pad and collate into batch
+            for i, tokens in enumerate(tokenss):
+                tokenss[i] = torch.nn.functional.pad(tokens, (0, max_len - len(tokens)), value=self.pad_token)
+            tokens = torch.stack(tokenss, dim=0)
+            batch["input_ids"] = tokens[:, : self.config.train_semantic_data_settings["max_text_tokens_len"]]
        return batch

    def format_batch_on_device(self, batch):
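The collation above right-pads every tokenized text to the batch maximum with the mode's pad token before stacking into a single tensor. A small numeric illustration of the same F.pad / stack pattern, with made-up token values and a made-up pad id:

import torch
import torch.nn.functional as F

pad_token = 999                                        # illustrative pad id
tokenss = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

max_len = max(len(t) for t in tokenss)
for i, tokens in enumerate(tokenss):
    tokenss[i] = F.pad(tokens, (0, max_len - len(tokens)), value=pad_token)
batch_ids = torch.stack(tokenss, dim=0)
# -> [[1, 2, 3], [4, 5, 999]]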
@@ -386,11 +391,36 @@ class Bark(BaseTTS):
        Returns:
            dict: Formatted batch.
        """
+        # TODO: Make padding and truncation based on exact length of the waveforms
        if self.config.training_mode == "semantic":
            batch["semantic_tokens"] = self.generate_semantic_tokens(batch["waveform"][:, 0])[
                :, : self.config.max_semantic_tokens_len
            ]
+        elif self.config.training_mode == "coarse":
+            semantic_to_coarse_ratio = (
+                self.config.COARSE_RATE_HZ / self.config.SEMANTIC_RATE_HZ * self.config.N_COARSE_CODEBOOKS
+            )
+
+            batch["semantic_tokens"] = self.generate_semantic_tokens(batch["waveform"][:, 0])[
+                :, : self.config.train_coarse_data_settings["max_semantic_tokens_len"]
+            ]
+            batch["semantic_tokens"] = torch.nn.functional.pad(
+                batch["semantic_tokens"], (0, 1), value=self.config.COARSE_INFER_TOKEN
+            )
+
+            batch["coarse_tokens"] = self.generate_coarse_fine_tokens(batch["waveform"])[1]
+            batch["coarse_tokens"] = (
+                batch["coarse_tokens"].flatten(start_dim=1)
+                + self.config.CODEBOOK_SIZE
+                + self.config.SEMANTIC_VOCAB_SIZE
+            )
+            batch["coarse_tokens"] = batch["coarse_tokens"][
+                :, : self.config.train_coarse_data_settings["max_coarse_tokens_len"]
+            ]
        elif self.config.training_mode == "fine":
            batch["coarse_tokens"], batch["fine_tokens"] = self.generate_coarse_fine_tokens(batch["waveform"])[
                :, : self.config.max_coarse_tokens_len
            ]
        return batch

    def train_step_semantic(self, batch: dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]:
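Two details of the new coarse branch are worth spelling out. First, semantic_to_coarse_ratio relates the lengths of the two token streams; with Bark's usual defaults (COARSE_RATE_HZ=75, SEMANTIC_RATE_HZ=49.9, N_COARSE_CODEBOOKS=2, quoted here as assumptions) it comes out to roughly 3 coarse tokens per semantic token. Second, the two coarse codebooks are flattened along the time axis and every id is shifted by CODEBOOK_SIZE + SEMANTIC_VOCAB_SIZE so the coarse range does not collide with semantic ids in the shared GPT vocabulary. A tiny numeric sketch of the flatten-and-offset step, using the assumed default constants:

import torch

semantic_to_coarse_ratio = 75 / 49.9 * 2           # ~3.0 coarse tokens per semantic token

coarse = torch.tensor([[[0, 1, 2],                  # codebook 0, 3 frames
                        [3, 4, 5]]])                # codebook 1, 3 frames -> shape [1, 2, 3]
flat = coarse.flatten(start_dim=1)                  # row-major: codebook 0 frames first, then codebook 1
offset = flat + 1024 + 10_000                       # shift past semantic ids and one codebook
# -> [[11024, 11025, 11026, 11027, 11028, 11029]]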
@@ -402,16 +432,26 @@ class Bark(BaseTTS):
        inputs = torch.cat([batch["input_ids"], input_tokens], dim=1)
        logits = self.semantic_model(inputs)

-        semantic_logits = logits[:, batch["input_ids"].size(1) :].contiguous()
+        logits = logits[:, batch["input_ids"].size(1) :].contiguous()

-        loss = criterion(
-            semantic_logits.view(-1, self.config.semantic_config.output_vocab_size), target_tokens.view(-1)
-        )
+        loss = criterion(logits.view(-1, self.config.semantic_gpt_config.output_vocab_size), target_tokens.view(-1))
        loss_dict = {"loss": loss}
        return {}, loss_dict

-    def train_step_coarse(self):
-        ...
+    def train_step_coarse(self, batch: dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]:
+        """Train coarse encoder"""
+        tokens = batch["coarse_tokens"]
+        target_tokens = tokens[:, 1:].contiguous()
+        input_tokens = tokens[:, :-1].contiguous()
+
+        inputs = torch.cat([batch["semantic_tokens"], input_tokens], dim=1)
+        logits = self.coarse_model(inputs)
+
+        logits = logits[:, batch["semantic_tokens"].size(1) :].contiguous()
+
+        loss = criterion(logits.view(-1, self.config.coarse_gpt_config.output_vocab_size), target_tokens.view(-1))
+        loss_dict = {"loss": loss}
+        return {}, loss_dict

    def train_step_fine(self):
        ...
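Both train_step_semantic and the new train_step_coarse follow the same next-token-prediction recipe: shift the modelled stream by one, prepend the conditioning stream (text tokens or semantic tokens) to the model input, and score only the logits that line up with target positions. A compact sketch of that index bookkeeping with made-up tensors:

import torch

prefix = torch.tensor([[101, 102]])             # conditioning stream (e.g. semantic tokens)
tokens = torch.tensor([[7, 8, 9, 10]])          # stream being modelled (e.g. coarse tokens)

input_tokens = tokens[:, :-1]                    # [[7, 8, 9]]   fed to the GPT after the prefix
target_tokens = tokens[:, 1:]                    # [[8, 9, 10]]  what each position must predict

inputs = torch.cat([prefix, input_tokens], dim=1)    # [[101, 102, 7, 8, 9]]
# after the forward pass, only the trailing input_tokens.size(1) positions are kept:
# logits = logits[:, prefix.size(1):]  ->  aligned one-to-one with target_tokens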
@@ -419,6 +459,10 @@ class Bark(BaseTTS):
    def train_step(self, *args, **kwargs):
        if self.config.training_mode == "semantic":
            return self.train_step_semantic(*args, **kwargs)
+        elif self.config.training_mode == "coarse":
+            return self.train_step_coarse(*args, **kwargs)
+        elif self.config.training_mode == "fine":
+            raise NotImplementedError()

    def eval_step(self, *args, **kwargs):
        self.train_step(*args, **kwargs)
@@ -436,26 +480,23 @@ class Bark(BaseTTS):
        return None

    def get_criterion(self):
-        if self.config.training_mode in ["semantic", "coarse_encoder"]:
-            return torch.nn.CrossEntropyLoss(ignore_index=self.config.COARSE_SEMANTIC_PAD_TOKEN)
-        elif self.config.training_mode == "fine_encoder":
-            return torch.nn.CrossEntropyLoss(ignore_index=self.config.FINE_COARSE_PAD_TOKEN)
-        else:
-            raise ValueError(f" ❗ Invalid training mode {self.config.training_mode}")
+        return torch.nn.CrossEntropyLoss(ignore_index=self.pad_token)

    def get_optimizer(self):
        if self.config.training_mode == "semantic":
            optimizer = get_optimizer(
                self.config.optimizer, self.config.optimizer_params, self.config.lr, self.semantic_model
            )
-        elif self.config.training_mode == "coarse_encoder":
+        elif self.config.training_mode == "coarse":
            optimizer = get_optimizer(
                self.config.optimizer, self.config.optimizer_params, self.config.lr, self.coarse_model
            )
-        elif self.config.training_mode == "fine_encoder":
+        elif self.config.training_mode == "fine":
            optimizer = get_optimizer(
                self.config.optimizer, self.config.optimizer_params, self.config.lr, self.fine_model
            )
        else:
            raise ValueError(" ❗ Invalid training mode: {}".format(self.config.training_mode))
        return optimizer

    def get_scheduler(self, optimizer):
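Putting the pieces together: the criterion now masks the mode-dependent pad token and the optimizer only receives the sub-model selected by training_mode. A rough manual training-step sketch, assuming model is a Bark instance with training_mode set to "coarse" and batch is a dict that has already passed through format_batch and format_batch_on_device (normally the Trainer drives this loop, so both names here are assumptions for illustration):

criterion = model.get_criterion()       # CrossEntropyLoss(ignore_index=model.pad_token)
optimizer = model.get_optimizer()       # wraps only the sub-model picked by training_mode

_, loss_dict = model.train_step(batch, criterion)   # dispatches to train_step_coarse here
loss_dict["loss"].backward()
optimizer.step()
optimizer.zero_grad()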
@@ -546,12 +587,52 @@


if __name__ == "__main__":
    # from TTS.tts.configs.bark_config import BarkConfig

    # bark_config = BarkConfig()

    # bark_config.training_mode = "semantic"
    # bark_config.batch_size = 2

    # bark = Bark.init_from_config(bark_config)

    # # batch = {"waveform": torch.randn(2, 48000), "raw_text": ["hello world", "how are you"]}
    # # batch = bark.format_batch(batch)
    # # batch = bark.format_batch_on_device(batch)

    # from trainer import Trainer, TrainerArgs

    # dataset_config = BaseDatasetConfig(
    #     formatter="ljspeech", meta_file_train="metadata.csv", path="/data/TTS-public/tests/data/ljspeech/"
    # )

    # train_samples, eval_samples = load_tts_samples(
    #     dataset_config,
    #     eval_split=True,
    #     eval_split_max_size=4,
    #     eval_split_size=4,
    # )

    # trainer = Trainer(
    #     model=bark,
    #     config=bark_config,
    #     output_path="./",
    #     args=TrainerArgs(),
    #     train_samples=train_samples,
    #     eval_samples=eval_samples,
    # )
    # trainer.fit()

    from TTS.tts.configs.bark_config import BarkConfig

    bark_config = BarkConfig()

-    bark_config.training_mode = "semantic"
+    bark_config.training_mode = "coarse"
    bark_config.batch_size = 2
    bark_config.run_eval = False
    bark_config.save_checkpoints = False
    bark_config.save_best_after = 100000
    bark_config.print_step = 1

    bark = Bark.init_from_config(bark_config)
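The commented-out block above already shows how this __main__ experiment is meant to be completed; for coarse finetuning the same pattern applies to the bark and bark_config objects created at the end. A sketch following that commented code: the dataset path, metadata file and split sizes are placeholders taken from the comments, and the exact import locations of BaseDatasetConfig and load_tts_samples are assumptions based on the usual TTS recipe layout.

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples

dataset_config = BaseDatasetConfig(
    formatter="ljspeech", meta_file_train="metadata.csv", path="/data/TTS-public/tests/data/ljspeech/"
)
train_samples, eval_samples = load_tts_samples(
    dataset_config, eval_split=True, eval_split_max_size=4, eval_split_size=4
)

trainer = Trainer(
    model=bark,
    config=bark_config,
    output_path="./",
    args=TrainerArgs(),
    train_samples=train_samples,
    eval_samples=eval_samples,
)
trainer.fit()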