mirror of https://github.com/coqui-ai/TTS.git
Ensures that only GPT model is in training mode during XTTS GPT training (#3241)
* Ensures that only GPT model is in training mode during training * Fix parallel wavegan unit testpull/3239/head
parent
14579a4607
commit
11283fce07
|
@ -318,9 +318,10 @@ class GPTTrainer(BaseTTS):
|
|||
batch["cond_idxs"] = None
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def on_epoch_start(self, trainer): # pylint: disable=W0613
|
||||
# guarante that dvae will be in eval mode after .train() on evaluation end
|
||||
self.dvae = self.dvae.eval()
|
||||
def on_train_epoch_start(self, trainer):
|
||||
trainer.model.eval() # the whole model to eval
|
||||
# put gpt model in training mode
|
||||
trainer.model.xtts.gpt.train()
|
||||
|
||||
def on_init_end(self, trainer): # pylint: disable=W0613
|
||||
# ignore similarities.pth on clearml save/upload
|
||||
|
|
|
@ -94,6 +94,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
|
|||
use_noise_augment: bool = False
|
||||
use_cache: bool = True
|
||||
steps_to_start_discriminator: int = 200000
|
||||
target_loss: str = "loss_1"
|
||||
|
||||
# LOSS PARAMETERS - overrides
|
||||
use_stft_loss: bool = True
|
||||
|
|
|
@ -27,7 +27,7 @@ pandas>=1.4,<2.0
|
|||
# deps for training
|
||||
matplotlib>=3.7.0
|
||||
# coqui stack
|
||||
trainer
|
||||
trainer>=0.0.32
|
||||
# config management
|
||||
coqpit>=0.0.16
|
||||
# chinese g2p deps
|
||||
|
|
Loading…
Reference in New Issue