mirror of https://github.com/coqui-ai/TTS.git
Fix train_tts.py and uncomment code (#1051)
* Fix SE loading and language embedding logic * remove trailing white space * Uncomment resmapling code for SCLpull/1027/head
parent
58c38de58d
commit
e1accb6e28
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import torch
|
||||
|
||||
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
|
||||
from TTS.trainer import Trainer, TrainingArgs
|
||||
|
@ -53,15 +54,22 @@ def main():
|
|||
else:
|
||||
config.num_speakers = speaker_manager.num_speakers
|
||||
elif check_config_and_model_args(config, "use_d_vector_file", True):
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
|
||||
if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True):
|
||||
speaker_manager = SpeakerManager(
|
||||
d_vectors_file_path=config.model_args.d_vector_file,
|
||||
encoder_model_path=config.model_args.speaker_encoder_model_path,
|
||||
encoder_config_path=config.model_args.speaker_encoder_config_path,
|
||||
use_cuda=torch.cuda.is_available(),
|
||||
)
|
||||
else:
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
|
||||
config.num_speakers = speaker_manager.num_speakers
|
||||
if hasattr(config, "model_args"):
|
||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||
else:
|
||||
config.num_speakers = speaker_manager.num_speakers
|
||||
else:
|
||||
speaker_manager = None
|
||||
|
||||
if hasattr(config, "use_language_embedding") and config.use_language_embedding:
|
||||
if check_config_and_model_args(config, "use_language_embedding", True):
|
||||
language_manager = LanguageManager(config=config)
|
||||
if hasattr(config, "model_args"):
|
||||
config.model_args.num_languages = language_manager.num_languages
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Dict, List, Tuple
|
|||
|
||||
import torch
|
||||
|
||||
# import torchaudio
|
||||
import torchaudio
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
|
@ -419,21 +419,12 @@ class Vits(BaseTTS):
|
|||
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
|
||||
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
|
||||
):
|
||||
# TODO: change this with torchaudio Resample
|
||||
raise RuntimeError(
|
||||
" [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
|
||||
self.config.audio["sample_rate"],
|
||||
self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
|
||||
)
|
||||
)
|
||||
# pylint: disable=W0101,W0105
|
||||
""" self.audio_transform = torchaudio.transforms.Resample(
|
||||
self.audio_transform = torchaudio.transforms.Resample(
|
||||
orig_freq=self.audio_config["sample_rate"],
|
||||
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
|
||||
)
|
||||
else:
|
||||
self.audio_transform = None
|
||||
"""
|
||||
else:
|
||||
self.audio_transform = None
|
||||
|
||||
def _init_speaker_embedding(self):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
|
@ -458,6 +449,7 @@ class Vits(BaseTTS):
|
|||
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
|
||||
|
||||
if self.args.use_language_embedding and self.language_manager:
|
||||
print(" > initialization of language-embedding layers.")
|
||||
self.num_languages = self.language_manager.num_languages
|
||||
self.embedded_language_dim = self.args.embedded_language_dim
|
||||
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
|
||||
|
@ -643,8 +635,8 @@ class Vits(BaseTTS):
|
|||
|
||||
# resample audio to speaker encoder sample_rate
|
||||
# pylint: disable=W0105
|
||||
"""if self.audio_transform is not None:
|
||||
wavs_batch = self.audio_transform(wavs_batch)"""
|
||||
if self.audio_transform is not None:
|
||||
wavs_batch = self.audio_transform(wavs_batch)
|
||||
|
||||
pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue