update `extract_tts_spectrograms` for the new model API

pull/602/head
Eren Gölge 2021-05-27 15:18:36 +02:00
parent c673eb8ef8
commit 830306d2fd
1 changed files with 6 additions and 4 deletions

View File

@ -146,20 +146,22 @@ def inference(
elif speaker_embeddings is not None:
speaker_c = speaker_embeddings
model_output, *_ = model.inference_with_MAS(
outputs = model.inference_with_MAS(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
)
model_output = outputs['model_outputs']
model_output = model_output.transpose(1, 2).detach().cpu().numpy()
elif "tacotron" in model_name:
_, postnet_outputs, *_ = model(
cond_input = {'speaker_ids': speaker_ids, 'x_vectors': speaker_embeddings}
outputs = model(
text_input,
text_lengths,
mel_input,
mel_lengths,
speaker_ids=speaker_ids,
speaker_embeddings=speaker_embeddings,
cond_input
)
postnet_outputs = outputs['model_outputs']
# normalize tacotron output
if model_name == "tacotron":
mel_specs = []