mirror of https://github.com/coqui-ai/TTS.git
Fix speaker encoder init
parent
cc514b36bb
commit
56378b12f7
|
@ -11,7 +11,7 @@ from torch.utils.data import DataLoader
|
|||
|
||||
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
|
||||
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
|
||||
from TTS.speaker_encoder.utils.training import init_training
|
||||
from TTS.speaker_encoder.utils.visual import plot_embeddings
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
|
@ -151,7 +151,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
global meta_data_eval
|
||||
|
||||
ap = AudioProcessor(**c.audio)
|
||||
model = setup_model(c)
|
||||
model = setup_speaker_encoder_model(c)
|
||||
|
||||
optimizer = RAdam(model.parameters(), lr=c.lr)
|
||||
|
||||
|
|
|
@ -100,7 +100,15 @@ if args.vocoder_path is not None:
|
|||
|
||||
# load models
|
||||
synthesizer = Synthesizer(
|
||||
model_path, config_path, speakers_file_path, None, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
|
||||
tts_checkpoint=model_path,
|
||||
tts_config_path=config_path,
|
||||
tts_speakers_file=speakers_file_path,
|
||||
tts_languages_file=None,
|
||||
vocoder_checkpoint=vocoder_path,
|
||||
vocoder_config=vocoder_config_path,
|
||||
encoder_checkpoint="",
|
||||
encoder_config="",
|
||||
use_cuda=args.use_cuda,
|
||||
)
|
||||
|
||||
use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
|
||||
|
|
|
@ -170,7 +170,7 @@ class Synthesizer(object):
|
|||
|
||||
def _init_speaker_encoder(self, speaker_manager):
|
||||
"""Initialize the SpeakerEncoder"""
|
||||
if self.encoder_checkpoint is not None:
|
||||
if self.encoder_checkpoint:
|
||||
speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
|
||||
return speaker_manager
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.config import load_config
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_model
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
from TTS.speaker_encoder.utils.io import save_checkpoint
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
@ -28,7 +28,7 @@ class SpeakerManagerTest(unittest.TestCase):
|
|||
config.audio.resample = True
|
||||
|
||||
# create a dummy speaker encoder
|
||||
model = setup_model(config)
|
||||
model = setup_speaker_encoder_model(config)
|
||||
save_checkpoint(model, None, None, get_tests_input_path(), 0)
|
||||
|
||||
# load audio processor and speaker encoder
|
||||
|
|
Loading…
Reference in New Issue