Fix test sentences synthesis

remotes/WeberJulian/fix-test-sentences
WeberJulian 2021-07-13 16:04:42 +02:00
parent 93a74cbb71
commit 32974dd6a9
2 changed files with 9 additions and 9 deletions

View File

@ -764,11 +764,11 @@ class Trainer:
"""Run test and log the results. Test run must be defined by the model.
Model must return figures and audios to be logged by the Tensorboard."""
if hasattr(self.model, "test_run"):
if hasattr(self.eval_loader.load_test_samples):
if hasattr(self.eval_loader, "load_test_samples"):
samples = self.eval_loader.load_test_samples(1)
figures, audios = self.model.test_run(samples)
else:
figures, audios = self.model.test_run()
figures, audios = self.model.test_run(use_cuda=self.use_cuda, ap=self.ap)
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, figures)
@ -790,7 +790,7 @@ class Trainer:
self.train_epoch()
if self.config.run_eval:
self.eval_epoch()
if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
self.test_run()
self.c_logger.print_epoch_end(
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values

View File

@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
)
return loader
def test_run(self) -> Tuple[Dict, Dict]:
def test_run(self, use_cuda=True, ap=None) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
@ -212,14 +212,14 @@ class BaseTTS(BaseModel):
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_aux_inputs()
aux_inputs = self.get_aux_input()
for idx, sen in enumerate(test_sentences):
wav, alignment, model_outputs, _ = synthesis(
self.model,
self,
sen,
self.config,
self.use_cuda,
self.ap,
use_cuda,
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
@ -229,6 +229,6 @@ class BaseTTS(BaseModel):
).values()
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
return test_figures, test_audios