mirror of https://github.com/coqui-ai/TTS.git
Minor bug fixes on VITS/YourTTS and inference (#2054)
* Set the right device to the speaker encoder * Bug fix on inference list_language_idxs parameter * Bug fix on speaker encoder resample audio transform (pull/2066/head)
parent
5f5d441ee5
commit
f3b947e706
|
@ -331,7 +331,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
|||
print(
|
||||
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
|
||||
)
|
||||
print(synthesizer.tts_model.language_manager.ids)
|
||||
print(synthesizer.tts_model.language_manager.name_to_id)
|
||||
return
|
||||
|
||||
# check the arguments against a multi-speaker model.
|
||||
|
|
|
@ -721,6 +721,10 @@ class Vits(BaseTTS):
|
|||
use_spectral_norm=self.args.use_spectral_norm_disriminator,
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def init_multispeaker(self, config: Coqpit):
|
||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
|
@ -758,17 +762,12 @@ class Vits(BaseTTS):
|
|||
|
||||
if (
|
||||
hasattr(self.speaker_manager.encoder, "audio_config")
|
||||
and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"]
|
||||
and self.config.audio.sample_rate != self.speaker_manager.encoder.audio_config["sample_rate"]
|
||||
):
|
||||
self.audio_transform = torchaudio.transforms.Resample(
|
||||
orig_freq=self.audio_config["sample_rate"],
|
||||
orig_freq=self.config.audio.sample_rate,
|
||||
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
|
||||
)
|
||||
# pylint: disable=W0101,W0105
|
||||
self.audio_transform = torchaudio.transforms.Resample(
|
||||
orig_freq=self.config.audio.sample_rate,
|
||||
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
|
||||
)
|
||||
|
||||
def _init_speaker_embedding(self):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
|
@ -811,6 +810,13 @@ class Vits(BaseTTS):
|
|||
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
|
||||
) # pylint: disable=W0201
|
||||
|
||||
def on_epoch_start(self, trainer): # pylint: disable=W0613
|
||||
"""Freeze layers at the beginning of an epoch"""
|
||||
self._freeze_layers()
|
||||
# set the device of speaker encoder
|
||||
if self.args.use_speaker_encoder_as_loss:
|
||||
self.speaker_manager.encoder = self.speaker_manager.encoder.to(self.device)
|
||||
|
||||
def on_init_end(self, trainer): # pylint: disable=W0613
|
||||
"""Reinit layes if needed"""
|
||||
if self.args.reinit_DP:
|
||||
|
@ -1231,8 +1237,6 @@ class Vits(BaseTTS):
|
|||
Tuple[Dict, Dict]: Model outputs and computed losses.
|
||||
"""
|
||||
|
||||
self._freeze_layers()
|
||||
|
||||
spec_lens = batch["spec_lens"]
|
||||
|
||||
if optimizer_idx == 0:
|
||||
|
|
Loading…
Reference in New Issue