Ensures that only GPT model is in training mode during XTTS GPT training (#3241)

* Ensures that only GPT model is in training mode during training

* Fix parallel wavegan unit test
pull/3239/head
Edresson Casanova 2023-11-17 11:13:46 -03:00 committed by GitHub
parent 14579a4607
commit 11283fce07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 4 deletions

View File

@ -318,9 +318,10 @@ class GPTTrainer(BaseTTS):
batch["cond_idxs"] = None
return self.train_step(batch, criterion)
def on_epoch_start(self, trainer): # pylint: disable=W0613
# guarante that dvae will be in eval mode after .train() on evaluation end
self.dvae = self.dvae.eval()
def on_train_epoch_start(self, trainer):
trainer.model.eval() # the whole model to eval
# put gpt model in training mode
trainer.model.xtts.gpt.train()
def on_init_end(self, trainer): # pylint: disable=W0613
# ignore similarities.pth on clearml save/upload

View File

@ -94,6 +94,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
use_noise_augment: bool = False
use_cache: bool = True
steps_to_start_discriminator: int = 200000
target_loss: str = "loss_1"
# LOSS PARAMETERS - overrides
use_stft_loss: bool = True

View File

@ -27,7 +27,7 @@ pandas>=1.4,<2.0
# deps for training
matplotlib>=3.7.0
# coqui stack
trainer
trainer>=0.0.32
# config management
coqpit>=0.0.16
# chinese g2p deps