mirror of https://github.com/coqui-ai/TTS.git
Fix test sentences synthesis
parent
93a74cbb71
commit
32974dd6a9
|
@ -764,11 +764,11 @@ class Trainer:
|
|||
"""Run test and log the results. Test run must be defined by the model.
|
||||
Model must return figures and audios to be logged by the Tensorboard."""
|
||||
if hasattr(self.model, "test_run"):
|
||||
if hasattr(self.eval_loader.load_test_samples):
|
||||
if hasattr(self.eval_loader, "load_test_samples"):
|
||||
samples = self.eval_loader.load_test_samples(1)
|
||||
figures, audios = self.model.test_run(samples)
|
||||
else:
|
||||
figures, audios = self.model.test_run()
|
||||
figures, audios = self.model.test_run(use_cuda=self.use_cuda, ap=self.ap)
|
||||
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
|
||||
self.tb_logger.tb_test_figures(self.total_steps_done, figures)
|
||||
|
||||
|
@ -790,7 +790,7 @@ class Trainer:
|
|||
self.train_epoch()
|
||||
if self.config.run_eval:
|
||||
self.eval_epoch()
|
||||
if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
|
||||
if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
|
||||
self.test_run()
|
||||
self.c_logger.print_epoch_end(
|
||||
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
|
||||
|
|
|
@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
|
|||
)
|
||||
return loader
|
||||
|
||||
def test_run(self) -> Tuple[Dict, Dict]:
|
||||
def test_run(self, use_cuda=True, ap=None) -> Tuple[Dict, Dict]:
|
||||
"""Generic test run for `tts` models used by `Trainer`.
|
||||
|
||||
You can override this for a different behaviour.
|
||||
|
@ -212,14 +212,14 @@ class BaseTTS(BaseModel):
|
|||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self._get_aux_inputs()
|
||||
aux_inputs = self.get_aux_input()
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
wav, alignment, model_outputs, _ = synthesis(
|
||||
self.model,
|
||||
self,
|
||||
sen,
|
||||
self.config,
|
||||
self.use_cuda,
|
||||
self.ap,
|
||||
use_cuda,
|
||||
ap,
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
|
@ -229,6 +229,6 @@ class BaseTTS(BaseModel):
|
|||
).values()
|
||||
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
|
||||
return test_figures, test_audios
|
||||
|
|
Loading…
Reference in New Issue