mirror of https://github.com/coqui-ai/TTS.git
Fix `test_run` for DDP
parent
7c0d564965
commit
c8bbcdfd07
|
@ -829,9 +829,15 @@ class Trainer:
|
||||||
|
|
||||||
if hasattr(self.eval_loader.dataset, "load_test_samples"):
|
if hasattr(self.eval_loader.dataset, "load_test_samples"):
|
||||||
samples = self.eval_loader.dataset.load_test_samples(1)
|
samples = self.eval_loader.dataset.load_test_samples(1)
|
||||||
figures, audios = self.model.test_run(self.ap, samples, None)
|
if self.num_gpus > 1:
|
||||||
|
figures, audios = self.model.module.test_run(self.ap, samples, None)
|
||||||
|
else:
|
||||||
|
figures, audios = self.model.test_run(self.ap, samples, None)
|
||||||
else:
|
else:
|
||||||
figures, audios = self.model.test_run(self.ap)
|
if self.num_gpus > 1:
|
||||||
|
figures, audios = self.model.module.test_run(self.ap)
|
||||||
|
else:
|
||||||
|
figures, audios = self.model.test_run(self.ap)
|
||||||
self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
|
self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
|
||||||
self.dashboard_logger.test_figures(self.total_steps_done, figures)
|
self.dashboard_logger.test_figures(self.total_steps_done, figures)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue