From 5307a2229baed242f3f60ade3fa113c69f1f9e7f Mon Sep 17 00:00:00 2001 From: Victor Shepardson Date: Tue, 1 Nov 2022 12:52:06 +0100 Subject: [PATCH] Fix Capacitron training (#2086) --- TTS/tts/models/base_tts.py | 2 +- TTS/utils/capacitron_optimizer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 88c84d08..d2222acb 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -344,7 +344,7 @@ class BaseTTS(BaseTrainerModel): loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, - shuffle=False, # shuffle is done in the dataset. + shuffle=True, # if there is no other sampler collate_fn=dataset.collate_fn, drop_last=False, # setting this False might cause issues in AMP training. sampler=sampler, diff --git a/TTS/utils/capacitron_optimizer.py b/TTS/utils/capacitron_optimizer.py index fac7d8a0..7206ffd5 100644 --- a/TTS/utils/capacitron_optimizer.py +++ b/TTS/utils/capacitron_optimizer.py @@ -38,9 +38,9 @@ class CapacitronOptimizer: self.param_groups = self.primary_optimizer.param_groups self.primary_optimizer.step() - def zero_grad(self): - self.primary_optimizer.zero_grad() - self.secondary_optimizer.zero_grad() + def zero_grad(self, set_to_none=False): + self.primary_optimizer.zero_grad(set_to_none) + self.secondary_optimizer.zero_grad(set_to_none) def load_state_dict(self, state_dict): self.primary_optimizer.load_state_dict(state_dict[0])