mirror of https://github.com/coqui-ai/TTS.git
Fix checkpointing GAN models (#1641)
* checkpoint sae step crash fix * checkpoint save step crash fix * Update gan.py updated requested changes * crash fixpull/1469/head^2
parent
00e67092d8
commit
577ec406f4
|
@ -185,8 +185,7 @@ class GAN(BaseVocoder):
|
||||||
outputs = {"model_outputs": self.y_hat_g}
|
outputs = {"model_outputs": self.y_hat_g}
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
@staticmethod
|
def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
|
||||||
def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
|
|
||||||
"""Logging shared by the training and evaluation.
|
"""Logging shared by the training and evaluation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -198,7 +197,7 @@ class GAN(BaseVocoder):
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Dict, Dict]: log figures and audio samples.
|
Tuple[Dict, Dict]: log figures and audio samples.
|
||||||
"""
|
"""
|
||||||
y_hat = outputs[0]["model_outputs"]
|
y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
|
||||||
y = batch["waveform"]
|
y = batch["waveform"]
|
||||||
figures = plot_results(y_hat, y, ap, name)
|
figures = plot_results(y_hat, y, ap, name)
|
||||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||||
|
|
Loading…
Reference in New Issue