diff --git a/TTS/trainer.py b/TTS/trainer.py index c56be140..bbd9665a 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -764,11 +764,11 @@ class Trainer: """Run test and log the results. Test run must be defined by the model. Model must return figures and audios to be logged by the Tensorboard.""" if hasattr(self.model, "test_run"): - if hasattr(self.eval_loader.load_test_samples): + if hasattr(self.eval_loader, "load_test_samples"): samples = self.eval_loader.load_test_samples(1) figures, audios = self.model.test_run(samples) else: - figures, audios = self.model.test_run() + figures, audios = self.model.test_run(use_cuda=self.use_cuda, ap=self.ap) self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.tb_logger.tb_test_figures(self.total_steps_done, figures) @@ -790,7 +790,7 @@ class Trainer: self.train_epoch() if self.config.run_eval: self.eval_epoch() - if epoch >= self.config.test_delay_epochs and self.args.rank < 0: + if epoch >= self.config.test_delay_epochs and self.args.rank <= 0: self.test_run() self.c_logger.print_epoch_end( epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 2ec268d6..64c0ba6f 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -200,7 +200,7 @@ class BaseTTS(BaseModel): ) return loader - def test_run(self) -> Tuple[Dict, Dict]: + def test_run(self, use_cuda=True, ap=None) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. You can override this for a different behaviour. @@ -212,14 +212,14 @@ class BaseTTS(BaseModel): test_audios = {} test_figures = {} test_sentences = self.config.test_sentences - aux_inputs = self._get_aux_inputs() + aux_inputs = self.get_aux_input() for idx, sen in enumerate(test_sentences): wav, alignment, model_outputs, _ = synthesis( - self.model, + self, sen, self.config, - self.use_cuda, - self.ap, + use_cuda, + ap, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], @@ -229,6 +229,6 @@ class BaseTTS(BaseModel): ).values() test_audios["{}-audio".format(idx)] = wav - test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False) + test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False) test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False) return test_figures, test_audios