Style extract_tts_spectrogram.py

pull/800/head
Eren Gölge 2021-09-10 08:21:21 +00:00
parent 1de010acd4
commit 3c740d4893
1 changed file with 13 additions and 9 deletions


@@ -76,14 +76,14 @@ def set_filename(wav_path, out_path):
 def format_data(data):
     # setup input data
-    text_input = data['text']
-    text_lengths = data['text_lengths']
-    mel_input = data['mel']
-    mel_lengths = data['mel_lengths']
-    item_idx = data['item_idxs']
-    d_vectors = data['d_vectors']
-    speaker_ids = data['speaker_ids']
-    attn_mask = data['attns']
+    text_input = data["text"]
+    text_lengths = data["text_lengths"]
+    mel_input = data["mel"]
+    mel_lengths = data["mel_lengths"]
+    item_idx = data["item_idxs"]
+    d_vectors = data["d_vectors"]
+    speaker_ids = data["speaker_ids"]
+    attn_mask = data["attns"]
     avg_text_length = torch.mean(text_lengths.float())
     avg_spec_length = torch.mean(mel_lengths.float())
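
For context, the keys unpacked in this hunk imply that the data loader's collate step yields a dictionary-shaped batch. Below is a minimal sketch of such a batch, using only the key names visible in the diff; the tensor shapes, dtypes, and example values are illustrative assumptions, not taken from the script.

import torch

# Hypothetical batch as consumed by format_data(); shapes and values are assumptions.
data = {
    "text": torch.randint(0, 100, (2, 50)),             # token IDs, [batch, max_text_len]
    "text_lengths": torch.tensor([50, 42]),             # unpadded length of each text
    "mel": torch.randn(2, 200, 80),                     # mel spectrograms, [batch, max_frames, n_mels]
    "mel_lengths": torch.tensor([200, 173]),            # unpadded frame count of each mel
    "item_idxs": ["clips/0001.wav", "clips/0002.wav"],  # source audio paths (hypothetical)
    "d_vectors": None,                                   # speaker embeddings, if the model uses them
    "speaker_ids": None,                                  # integer speaker IDs, if the model uses them
    "attns": None,                                        # precomputed attention maps, if available
}

With a batch like this, torch.mean(text_lengths.float()) and torch.mean(mel_lengths.float()) give the averages computed in the two trailing context lines of the hunk.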
@@ -132,7 +132,11 @@ def inference(
     elif d_vectors is not None:
         speaker_c = d_vectors
     outputs = model.inference_with_MAS(
-        text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}
+        text_input,
+        text_lengths,
+        mel_input,
+        mel_lengths,
+        aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
     )
     model_output = outputs["model_outputs"]
     model_output = model_output.transpose(1, 2).detach().cpu().numpy()
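
The second hunk only reflows the inference_with_MAS call into one argument per line; behaviour is unchanged. What typically follows in a spectrogram-extraction script is saving one trimmed array per batch item. The sketch below is a hypothetical continuation, not part of this diff: np.save, the output file naming, and the per-item trimming layout are all assumptions.

import os
import numpy as np

# Hypothetical continuation after the transpose above; item_idx and mel_lengths
# come from format_data(). Assumes each item is [n_mels, frames] after the transpose.
for i, wav_path in enumerate(item_idx):
    spec = model_output[i][:, : int(mel_lengths[i])]  # drop padded frames
    out_file = os.path.splitext(os.path.basename(wav_path))[0] + ".npy"  # hypothetical naming
    np.save(out_file, spec)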