mirror of https://github.com/coqui-ai/TTS.git
Fix checkpointing GAN models (#1641)
* checkpoint sae step crash fix * checkpoint save step crash fix * Update gan.py updated requested changes * crash fixpull/1469/head^2
parent
00e67092d8
commit
577ec406f4
|
@ -185,8 +185,7 @@ class GAN(BaseVocoder):
|
|||
outputs = {"model_outputs": self.y_hat_g}
|
||||
return outputs, loss_dict
|
||||
|
||||
@staticmethod
|
||||
def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
|
||||
def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
|
||||
"""Logging shared by the training and evaluation.
|
||||
|
||||
Args:
|
||||
|
@ -198,7 +197,7 @@ class GAN(BaseVocoder):
|
|||
Returns:
|
||||
Tuple[Dict, Dict]: log figures and audio samples.
|
||||
"""
|
||||
y_hat = outputs[0]["model_outputs"]
|
||||
y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
|
||||
y = batch["waveform"]
|
||||
figures = plot_results(y_hat, y, ap, name)
|
||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||
|
|
Loading…
Reference in New Issue