From 3c740d48936d379c98c0f2c453a5fb8c8de87af2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 10 Sep 2021 08:21:21 +0000 Subject: [PATCH] Style extract_tts_spectrogram.py --- TTS/bin/extract_tts_spectrograms.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 112ef046..681fcc36 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -76,14 +76,14 @@ def set_filename(wav_path, out_path): def format_data(data): # setup input data - text_input = data['text'] - text_lengths = data['text_lengths'] - mel_input = data['mel'] - mel_lengths = data['mel_lengths'] - item_idx = data['item_idxs'] - d_vectors = data['d_vectors'] - speaker_ids = data['speaker_ids'] - attn_mask = data['attns'] + text_input = data["text"] + text_lengths = data["text_lengths"] + mel_input = data["mel"] + mel_lengths = data["mel_lengths"] + item_idx = data["item_idxs"] + d_vectors = data["d_vectors"] + speaker_ids = data["speaker_ids"] + attn_mask = data["attns"] avg_text_length = torch.mean(text_lengths.float()) avg_spec_length = torch.mean(mel_lengths.float()) @@ -132,7 +132,11 @@ def inference( elif d_vectors is not None: speaker_c = d_vectors outputs = model.inference_with_MAS( - text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids} + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}, ) model_output = outputs["model_outputs"] model_output = model_output.transpose(1, 2).detach().cpu().numpy()