Fix checkpointing GAN models (#1641)

* checkpoint sae step crash fix

* checkpoint save step crash fix

* Update gan.py

updated requested changes

* crash fix
pull/1469/head^2
manmay nakhashi 2022-06-22 15:37:46 +05:30 committed by GitHub
parent 00e67092d8
commit 577ec406f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 3 deletions

View File

@ -185,8 +185,7 @@ class GAN(BaseVocoder):
outputs = {"model_outputs": self.y_hat_g}
return outputs, loss_dict
@staticmethod
def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
"""Logging shared by the training and evaluation.
Args:
@ -198,7 +197,7 @@ class GAN(BaseVocoder):
Returns:
Tuple[Dict, Dict]: log figures and audio samples.
"""
y_hat = outputs[0]["model_outputs"]
y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"]
y = batch["waveform"]
figures = plot_results(y_hat, y, ap, name)
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()