Minors bug fixes on VITS/YourTTS and inference (#2054)

* Set the right device to the speaker encoder

* Bug fix on inference list_language_idxs parameter

* Bug fix on speaker encoder resample audio transform
pull/2066/head
Edresson Casanova 2022-10-06 17:23:54 -03:00 committed by GitHub
parent 5f5d441ee5
commit f3b947e706
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 10 deletions

View File

@ -331,7 +331,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
)
print(synthesizer.tts_model.language_manager.ids)
print(synthesizer.tts_model.language_manager.name_to_id)
return
# check the arguments against a multi-speaker model.

View File

@ -721,6 +721,10 @@ class Vits(BaseTTS):
use_spectral_norm=self.args.use_spectral_norm_disriminator,
)
@property
def device(self):
return next(self.parameters()).device
def init_multispeaker(self, config: Coqpit):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
@ -758,17 +762,12 @@ class Vits(BaseTTS):
if (
hasattr(self.speaker_manager.encoder, "audio_config")
and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"]
and self.config.audio.sample_rate != self.speaker_manager.encoder.audio_config["sample_rate"]
):
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.audio_config["sample_rate"],
orig_freq=self.config.audio.sample_rate,
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
)
# pylint: disable=W0101,W0105
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.config.audio.sample_rate,
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
)
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
@ -811,6 +810,13 @@ class Vits(BaseTTS):
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
) # pylint: disable=W0201
def on_epoch_start(self, trainer): # pylint: disable=W0613
"""Freeze layers at the beginning of an epoch"""
self._freeze_layers()
# set the device of speaker encoder
if self.args.use_speaker_encoder_as_loss:
self.speaker_manager.encoder = self.speaker_manager.encoder.to(self.device)
def on_init_end(self, trainer): # pylint: disable=W0613
"""Reinit layes if needed"""
if self.args.reinit_DP:
@ -1231,8 +1237,6 @@ class Vits(BaseTTS):
Tuple[Dict, Dict]: Model ouputs and computed losses.
"""
self._freeze_layers()
spec_lens = batch["spec_lens"]
if optimizer_idx == 0: