diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index deac7fc5..26a4b2f4 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -13,7 +13,7 @@ from TTS.config import load_config
 from TTS.tts.datasets import load_meta_data
 from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.models import setup_model
-from TTS.tts.utils.speakers import parse_speakers
+from TTS.tts.utils.speakers import get_speaker_manager
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters
@@ -39,7 +39,9 @@ def setup_loader(ap, r, verbose=False):
         enable_eos_bos=c.enable_eos_bos_chars,
         use_noise_augment=False,
         verbose=verbose,
-        speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None,
+        speaker_mapping=speaker_manager.speaker_ids
+        if c.use_speaker_embedding and c.use_external_speaker_embedding_file
+        else None,
     )
 
     if c.use_phonemes and c.compute_input_seq_cache:
@@ -91,7 +93,7 @@ def format_data(data):
             speaker_embeddings = data[8]
             speaker_ids = None
         else:
-            speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
+            speaker_ids = [speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
             speaker_ids = torch.LongTensor(speaker_ids)
             speaker_embeddings = None
     else:
@@ -134,12 +136,11 @@ def inference(
     text_lengths,
     mel_input,
     mel_lengths,
-    attn_mask=None,
     speaker_ids=None,
     speaker_embeddings=None,
 ):
     if model_name == "glow_tts":
-        mel_input = mel_input.permute(0, 2, 1)  # B x D x T
+        # mel_input = mel_input.permute(0, 2, 1) # B x D x T
         speaker_c = None
         if speaker_ids is not None:
             speaker_c = speaker_ids
@@ -147,9 +148,9 @@ def inference(
             speaker_c = speaker_embeddings
 
         outputs = model.inference_with_MAS(
-            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
+            text_input, text_lengths, mel_input, mel_lengths, cond_input={"x_vectors": speaker_c}
         )
-        model_output = outputs['model_outputs']
+        model_output = outputs["model_outputs"]
         model_output = model_output.transpose(1, 2).detach().cpu().numpy()
 
     elif "tacotron" in model_name:
@@ -193,7 +194,7 @@ def extract_spectrograms(
             speaker_embeddings,
             _,
             _,
-            attn_mask,
+            _,
             item_idx,
         ) = format_data(data)
 
@@ -205,7 +206,6 @@ def extract_spectrograms(
             text_lengths,
             mel_input,
             mel_lengths,
-            attn_mask,
             speaker_ids,
             speaker_embeddings,
         )
@@ -242,7 +242,7 @@ def extract_spectrograms(
 
 def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data, symbols, phonemes, model_characters, speaker_mapping
+    global meta_data, symbols, phonemes, model_characters, speaker_manager
 
     # Audio processor
     ap = AudioProcessor(**c.audio)
@@ -260,10 +260,10 @@ def main(args):  # pylint: disable=redefined-outer-name
     meta_data = meta_data_train + meta_data_eval
 
     # parse speakers
-    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, None)
+    speaker_manager = get_speaker_manager(c, args, meta_data_train)
 
     # setup model
-    model = setup_model(num_chars, speaker_manager.num_speakers, c, speaker_embedding_dim=speaker_manager.x_vector_dim)
+    model = setup_model(num_chars, speaker_manager.num_speakers, c, speaker_embedding_dim=speaker_manager.x_vector_dim)
 
     # restore model
     checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
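
For context, the refactor replaces the `(num_speakers, speaker_embedding_dim, speaker_mapping)` tuple from `parse_speakers` with a single speaker manager object. The sketch below only restates the new call pattern visible in the hunks above; the attribute names come from this diff, and `c`, `args`, `meta_data_train`, and `num_chars` are assumed to be the config, CLI args, training split, and character count already set up in the script.

    # Minimal sketch of the new speaker handling in main(), based on this diff.
    speaker_manager = get_speaker_manager(c, args, meta_data_train)

    # The manager carries the values the old tuple used to provide:
    #   speaker_manager.num_speakers  replaces num_speakers
    #   speaker_manager.x_vector_dim  replaces speaker_embedding_dim
    #   speaker_manager.speaker_ids   replaces the speaker_mapping dict
    model = setup_model(num_chars, speaker_manager.num_speakers, c, speaker_embedding_dim=speaker_manager.x_vector_dim)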