diff --git a/TTS/trainer.py b/TTS/trainer.py index 32e561d6..d3d66ab2 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -829,9 +829,15 @@ class Trainer: if hasattr(self.eval_loader.dataset, "load_test_samples"): samples = self.eval_loader.dataset.load_test_samples(1) - figures, audios = self.model.test_run(self.ap, samples, None) + if self.num_gpus > 1: + figures, audios = self.model.module.test_run(self.ap, samples, None) + else: + figures, audios = self.model.test_run(self.ap, samples, None) else: - figures, audios = self.model.test_run(self.ap) + if self.num_gpus > 1: + figures, audios = self.model.module.test_run(self.ap) + else: + figures, audios = self.model.test_run(self.ap) self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.dashboard_logger.test_figures(self.total_steps_done, figures)