Implement coarse model finetuning

finetune_bark
Eren Gölge 2023-08-21 12:33:56 +02:00
parent 0b1312c0c0
commit 3bd9f217dc
1 changed file with 116 additions and 35 deletions


@ -63,7 +63,9 @@ class BarkDataset(Dataset):
return {
"raw_text": raw_text,
"text_len": len(raw_text),
"wav": wav,
"wav_len": wav.shape[1],
"wav_file": wav_filename,
"speaker_name": item["speaker_name"],
"language_name": item["language"],
@ -143,9 +145,9 @@ class Bark(BaseTTS):
super().__init__(config=config, ap=None, tokenizer=None, speaker_manager=None, language_manager=None)
self.config.num_chars = len(tokenizer)
self.tokenizer = tokenizer
self.semantic_model = GPT(config.semantic_config)
self.coarse_model = GPT(config.coarse_config)
self.fine_model = FineGPT(config.fine_config)
self.semantic_model = GPT(config.semantic_gpt_config)
self.coarse_model = GPT(config.coarse_gpt_config)
self.fine_model = FineGPT(config.fine_gpt_config)
self.encodec = EncodecModel.encodec_model_24khz()
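# 6 kbps at 24 kHz yields 8 EnCodec codebooks; the first two form the coarse stream (see codes[:2] below)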
self.encodec.set_target_bandwidth(6.0)
self.semantic_tokenizer = BarkHubertAudioTokenizer(self.config, lazy_load=self.config.training_mode)
@ -154,11 +156,19 @@ class Bark(BaseTTS):
def device(self):
return next(self.parameters()).device
@property
def pad_token(self):
if self.config.training_mode == "semantic":
return self.config.SEMANTIC_PAD_TOKEN
elif self.config.training_mode in ["coarse", "fine"]:
return self.config.COARSE_SEMANTIC_PAD_TOKEN
else:
raise ValueError("Invalid training mode: {}".format(self.config.training_mode))
def load_bark_models(self):
self.semantic_model, self.config = load_model(
ckpt_path=self.config.LOCAL_MODEL_PATHS["text"], device=self.device, config=self.config, model_type="text"
)
self.coarse_model, self.config = load_model(
ckpt_path=self.config.LOCAL_MODEL_PATHS["coarse"],
device=self.device,
@ -169,17 +179,19 @@ class Bark(BaseTTS):
ckpt_path=self.config.LOCAL_MODEL_PATHS["fine"], device=self.device, config=self.config, model_type="fine"
)
def generate_coarse_fine_tokens(self, audio):
def generate_coarse_fine_tokens(
self,
audio,
):
if isinstance(audio, str):
audio, sr = torchaudio.load(audio)
audio = convert_audio_and_make_label(audio, sr, self.config.sample_rate, self.encodec.channels)
audio = convert_audio(audio, sr, self.config.sample_rate, self.encodec.channels)
audio = audio.unsqueeze(0).to(self.device)
# Coarse and fine tokens
with torch.no_grad():
encoded_frames = self.encodec.encode(audio)
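# encode() returns a list of (codes, scale) frames; codes are [B, n_q, T] and get concatenated along time below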
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]
codes = codes.cpu().numpy()
return codes, codes[:2, :]  # fine, coarse
def generate_semantic_tokens(self, audio):
@ -354,14 +366,6 @@ class Bark(BaseTTS):
Returns:
formatted batch
"""
PAD_TOKEN = None
if self.config.training_mode == "semantic":
PAD_TOKEN = self.config.SEMANTIC_PAD_TOKEN
elif self.config.training_mode in ["coarse", "fine"]:
PAD_TOKEN = self.config.COARSE_SEMANTIC_PAD_TOKEN
else:
raise ValueError("Invalid training mode: {}".format(self.config.training_mode))
tokenss = []
max_len = 0
for i, text in enumerate(batch["raw_text"]):
@ -370,11 +374,12 @@ class Bark(BaseTTS):
tokenss.append(tokens)
max_len = max(max_len, len(tokens))
# pad and collate into batch
for i, tokens in enumerate(tokenss):
tokenss[i] = torch.nn.functional.pad(tokens, (0, max_len - len(tokens)), value=PAD_TOKEN)
tokens = torch.stack(tokenss, dim=0)
batch["input_ids"] = tokens[:, : self.config.max_text_tokens_len]
if self.config.training_mode == "semantic":
# pad and collate into batch
for i, tokens in enumerate(tokenss):
tokenss[i] = torch.nn.functional.pad(tokens, (0, max_len - len(tokens)), value=self.pad_token)
tokens = torch.stack(tokenss, dim=0)
batch["input_ids"] = tokens[:, : self.config.train_semantic_data_settings["max_text_tokens_len"]]
return batch
def format_batch_on_device(self, batch):
@ -386,11 +391,36 @@ class Bark(BaseTTS):
Returns:
dict: Formatted batch.
"""
# TODO: Make padding and truncation based on exact length of the waveforms
if self.config.training_mode == "semantic":
batch["semantic_tokens"] = self.generate_semantic_tokens(batch["waveform"][:, 0])[
:, : self.config.max_semantic_tokens_len
]
elif self.config.training_mode == "coarse":
semantic_to_coarse_ratio = (
self.config.COARSE_RATE_HZ / self.config.SEMANTIC_RATE_HZ * self.config.N_COARSE_CODEBOOKS
)
batch["semantic_tokens"] = self.generate_semantic_tokens(batch["waveform"][:, 0])[
:, : self.config.train_coarse_data_settings["max_semantic_tokens_len"]
]
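# append the COARSE_INFER_TOKEN so the model sees where the semantic prompt ends and coarse prediction begins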
batch["semantic_tokens"] = torch.nn.functional.pad(
batch["semantic_tokens"], (0, 1), value=self.config.COARSE_INFER_TOKEN
)
batch["coarse_tokens"] = self.generate_coarse_fine_tokens(batch["waveform"])[1]
batch["coarse_tokens"] = (
batch["coarse_tokens"].flatten(start_dim=1)
+ self.config.CODEBOOK_SIZE
+ self.config.SEMANTIC_VOCAB_SIZE
)
batch["coarse_tokens"] = batch["coarse_tokens"][
:, : self.config.train_coarse_data_settings["max_coarse_tokens_len"]
]
elif self.config.training_mode == "fine":
batch["coarse_tokens"], batch["fine_tokens"] = self.generate_coarse_fine_tokens(batch["waveform"])[
:, : self.config.max_coarse_tokens_len
]
return batch
def train_step_semantic(self, batch: dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]:
@ -402,16 +432,26 @@ class Bark(BaseTTS):
inputs = torch.cat([batch["input_ids"], input_tokens], dim=1)
logits = self.semantic_model(inputs)
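# keep only the logits over the semantic continuation; positions belonging to the text prompt are not scored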
semantic_logits = logits[:, batch["input_ids"].size(1) :].contiguous()
logits = logits[:, batch["input_ids"].size(1) :].contiguous()
loss = criterion(
semantic_logits.view(-1, self.config.semantic_config.output_vocab_size), target_tokens.view(-1)
)
loss = criterion(logits.view(-1, self.config.semantic_gpt_config.output_vocab_size), target_tokens.view(-1))
loss_dict = {"loss": loss}
return {}, loss_dict
def train_step_coarse(self):
...
def train_step_coarse(self, batch: dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]:
"""Train coarse encoder"""
tokens = batch["coarse_tokens"]
target_tokens = tokens[:, 1:].contiguous()
input_tokens = tokens[:, :-1].contiguous()
inputs = torch.cat([batch["semantic_tokens"], input_tokens], dim=1)
logits = self.coarse_model(inputs)
logits = logits[:, batch["semantic_tokens"].size(1) :].contiguous()
loss = criterion(logits.view(-1, self.config.coarse_gpt_config.output_vocab_size), target_tokens.view(-1))
loss_dict = {"loss": loss}
return {}, loss_dict
def train_step_fine(self):
...
@ -419,6 +459,10 @@ class Bark(BaseTTS):
def train_step(self, *args, **kwargs):
if self.config.training_mode == "semantic":
return self.train_step_semantic(*args, **kwargs)
elif self.config.training_mode == "coarse":
return self.train_step_coarse(*args, **kwargs)
elif self.config.training_mode == "fine":
raise NotImplementedError()
def eval_step(self, *args, **kwargs):
return self.train_step(*args, **kwargs)
@ -436,26 +480,23 @@ class Bark(BaseTTS):
return None
def get_criterion(self):
if self.config.training_mode in ["semantic", "coarse_encoder"]:
return torch.nn.CrossEntropyLoss(ignore_index=self.config.COARSE_SEMANTIC_PAD_TOKEN)
elif self.config.training_mode == "fine_encoder":
return torch.nn.CrossEntropyLoss(ignore_index=self.config.FINE_COARSE_PAD_TOKEN)
else:
raise ValueError(f" ❗ Invalid training mode {self.config.training_mode}")
return torch.nn.CrossEntropyLoss(ignore_index=self.pad_token)
def get_optimizer(self):
if self.config.training_mode == "semantic":
optimizer = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr, self.semantic_model
)
elif self.config.training_mode == "coarse_encoder":
elif self.config.training_mode == "coarse":
optimizer = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr, self.coarse_model
)
elif self.config.training_mode == "fine_encoder":
elif self.config.training_mode == "fine":
optimizer = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr, self.fine_model
)
else:
raise ValueError(" ❗ Invalid training mode: {}".format(self.config.training_mode))
return optimizer
def get_scheduler(self, optimizer):
@ -546,12 +587,52 @@ class Bark(BaseTTS):
if __name__ == "__main__":
# from TTS.tts.configs.bark_config import BarkConfig
# bark_config = BarkConfig()
# bark_config.training_mode = "semantic"
# bark_config.batch_size = 2
# bark = Bark.init_from_config(bark_config)
# # batch = {"waveform": torch.randn(2, 48000), "raw_text": ["hello world", "how are you"]}
# # batch = bark.format_batch(batch)
# # batch = bark.format_batch_on_device(batch)
# from trainer import Trainer, TrainerArgs
# dataset_config = BaseDatasetConfig(
# formatter="ljspeech", meta_file_train="metadata.csv", path="/data/TTS-public/tests/data/ljspeech/"
# )
# train_samples, eval_samples = load_tts_samples(
# dataset_config,
# eval_split=True,
# eval_split_max_size=4,
# eval_split_size=4,
# )
# trainer = Trainer(
# model=bark,
# config=bark_config,
# output_path="./",
# args=TrainerArgs(),
# train_samples=train_samples,
# eval_samples=eval_samples,
# )
# trainer.fit()
from TTS.tts.configs.bark_config import BarkConfig
bark_config = BarkConfig()
bark_config.training_mode = "semantic"
bark_config.training_mode = "coarse"
bark_config.batch_size = 2
bark_config.run_eval = False
bark_config.save_checkpoints = False
bark_config.save_best_after = 100000
bark_config.print_step = 1
bark = Bark.init_from_config(bark_config)
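
As a usage reference only: a minimal sketch of how the entry point above could be driven end to end, following the commented-out setup earlier in this block and the standard Coqui Trainer API. The dataset formatter, path, and split sizes are illustrative placeholders, not part of this commit.

from trainer import Trainer, TrainerArgs

from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples

# Illustrative dataset definition; metadata file and path are placeholders.
dataset_config = BaseDatasetConfig(
    formatter="ljspeech", meta_file_train="metadata.csv", path="/data/TTS-public/tests/data/ljspeech/"
)
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split=True,
    eval_split_max_size=4,
    eval_split_size=4,
)
# `bark` and `bark_config` come from the lines above.
trainer = Trainer(
    TrainerArgs(),
    bark_config,
    output_path="./",
    model=bark,
    train_samples=train_samples,
    eval_samples=eval_samples,
)
trainer.fit()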