mirror of https://github.com/coqui-ai/TTS.git
update `extract_tts_spec...` using `SpeakerManager`
parent 830306d2fd
commit 667bb708b6
@@ -13,7 +13,7 @@ from TTS.config import load_config
 from TTS.tts.datasets import load_meta_data
 from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.models import setup_model
-from TTS.tts.utils.speakers import parse_speakers
+from TTS.tts.utils.speakers import get_speaker_manager
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters
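
The import swap above is the heart of the change: the free function parse_speakers is replaced by get_speaker_manager, which returns an object bundling the speaker metadata the script previously carried as loose variables. Below is a minimal sketch of only the surface this diff actually touches (speaker_ids, num_speakers, x_vector_dim); it is illustrative, not the real SpeakerManager implementation in TTS.tts.utils.speakers:

    # Sketch of the SpeakerManager surface used by this script
    # (illustrative stand-in, not the library's actual class).
    class SpeakerManagerSketch:
        def __init__(self, speaker_ids, x_vector_dim=0):
            # name -> integer id mapping, consumed by setup_loader() and format_data()
            self.speaker_ids = speaker_ids
            # embedding size when external x-vector files are used, else 0
            self.x_vector_dim = x_vector_dim

        @property
        def num_speakers(self):
            return len(self.speaker_ids)
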
@@ -39,7 +39,9 @@ def setup_loader(ap, r, verbose=False):
         enable_eos_bos=c.enable_eos_bos_chars,
         use_noise_augment=False,
         verbose=verbose,
-        speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None,
+        speaker_mapping=speaker_manager.speaker_ids
+        if c.use_speaker_embedding and c.use_external_speaker_embedding_file
+        else None,
     )

     if c.use_phonemes and c.compute_input_seq_cache:
@@ -91,7 +93,7 @@ def format_data(data):
             speaker_embeddings = data[8]
             speaker_ids = None
         else:
-            speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
+            speaker_ids = [speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
             speaker_ids = torch.LongTensor(speaker_ids)
             speaker_embeddings = None
     else:
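
speaker_manager.speaker_ids is used here as a drop-in replacement for the old speaker_mapping dict, so the name-to-id lookup itself is unchanged. A small usage sketch of the pattern in this hunk, with hypothetical names and ids standing in for real dataset speakers:

    import torch

    # hypothetical data, mirroring the lookup in format_data()
    speaker_ids_map = {"ljspeech": 0, "vctk_p225": 1}  # stands in for speaker_manager.speaker_ids
    speaker_names = ["vctk_p225", "ljspeech", "vctk_p225"]

    speaker_ids = [speaker_ids_map[name] for name in speaker_names]
    speaker_ids = torch.LongTensor(speaker_ids)  # tensor([1, 0, 1])
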
@@ -134,12 +136,11 @@ def inference(
     text_lengths,
     mel_input,
     mel_lengths,
-    attn_mask=None,
     speaker_ids=None,
     speaker_embeddings=None,
 ):
     if model_name == "glow_tts":
-        mel_input = mel_input.permute(0, 2, 1)  # B x D x T
+        # mel_input = mel_input.permute(0, 2, 1)  # B x D x T
         speaker_c = None
         if speaker_ids is not None:
             speaker_c = speaker_ids
@@ -147,9 +148,9 @@ def inference(
             speaker_c = speaker_embeddings

         outputs = model.inference_with_MAS(
-            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
+            text_input, text_lengths, mel_input, mel_lengths, cond_input={"x_vectors": speaker_c}
         )
-        model_output = outputs['model_outputs']
+        model_output = outputs["model_outputs"]
         model_output = model_output.transpose(1, 2).detach().cpu().numpy()

     elif "tacotron" in model_name:
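
Two API shifts are visible in this hunk: inference_with_MAS now receives its speaker conditioning through a cond_input dict keyed by "x_vectors" rather than the positional attn_mask and keyword g, and the outputs dict is indexed with double quotes for quoting consistency. A minimal sketch of the new call convention, with a hypothetical stand-in function (the real signature lives in the glow_tts model, not here):

    # Illustrative stand-in for the new call convention; not the real model code.
    def inference_with_mas_sketch(x, x_lengths, y, y_lengths, cond_input=None):
        cond_input = cond_input or {"x_vectors": None}
        g = cond_input.get("x_vectors")  # speaker conditioning, may be None
        # ... run the model with optional conditioning g, then
        # return a dict of named outputs instead of a bare tensor
        return {"model_outputs": y}
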
@@ -193,7 +194,7 @@ def extract_spectrograms(
             speaker_embeddings,
             _,
             _,
-            attn_mask,
+            _,
             item_idx,
         ) = format_data(data)

@@ -205,7 +206,6 @@ def extract_spectrograms(
             text_lengths,
             mel_input,
             mel_lengths,
-            attn_mask,
             speaker_ids,
             speaker_embeddings,
         )
@@ -242,7 +242,7 @@ def extract_spectrograms(

 def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data, symbols, phonemes, model_characters, speaker_mapping
+    global meta_data, symbols, phonemes, model_characters, speaker_manager

     # Audio processor
     ap = AudioProcessor(**c.audio)
@@ -260,10 +260,10 @@ def main(args):  # pylint: disable=redefined-outer-name
     meta_data = meta_data_train + meta_data_eval

     # parse speakers
-    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, None)
+    speaker_manager = get_speaker_manager(c, args, meta_data_train)

     # setup model
-    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
+    model = setup_model(num_chars, speaker_manager.num_speakers, c, speaker_embedding_dim=speaker_manager.x_vector_dim)

     # restore model
     checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
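
Taken together, the new setup path in main() replaces three loose return values (num_speakers, speaker_embedding_dim, speaker_mapping) with a single manager object. A condensed sketch of the resulting flow, using only calls that appear in this diff and assuming c, args, meta_data_train, and num_chars are in scope as in the script:

    # New flow in main(), condensed from the hunks above.
    speaker_manager = get_speaker_manager(c, args, meta_data_train)
    model = setup_model(
        num_chars,
        speaker_manager.num_speakers,  # was the loose num_speakers variable
        c,
        speaker_embedding_dim=speaker_manager.x_vector_dim,  # was speaker_embedding_dim
    )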