From aadb2106ec41d5365260a6d7b6a3ecdb03f3b405 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 16 Apr 2021 16:01:40 +0200
Subject: [PATCH] code styling

---
 TTS/server/server.py       | 16 +++++++---------
 TTS/tts/utils/synthesis.py | 22 +++++++++-------------
 TTS/utils/generic_utils.py |  8 +++++---
 TTS/utils/synthesizer.py   | 30 +++++++++++-------------------
 4 files changed, 32 insertions(+), 44 deletions(-)

diff --git a/TTS/server/server.py b/TTS/server/server.py
index 2275e825..5239f531 100644
--- a/TTS/server/server.py
+++ b/TTS/server/server.py
@@ -7,10 +7,10 @@ from pathlib import Path
 
 from flask import Flask, render_template, request, send_file
 
+from TTS.utils.generic_utils import style_wav_uri_to_dict
 from TTS.utils.io import load_config
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
-from TTS.utils.generic_utils import style_wav_uri_to_dict
 
 
 def create_argparser():
@@ -90,11 +90,9 @@ app = Flask(__name__)
 @app.route("/")
 def index():
     return render_template(
-        'index.html',
-        show_details=args.show_details,
-        use_speaker_embedding=use_speaker_embedding,
-        use_gst = use_gst
-    )
+        "index.html", show_details=args.show_details, use_speaker_embedding=use_speaker_embedding, use_gst=use_gst
+    )
+
 
 @app.route("/details")
 def details():
@@ -115,9 +113,9 @@ def details():
 
 @app.route("/api/tts", methods=["GET"])
 def tts():
-    text = request.args.get('text')
-    speaker_json_key = request.args.get('speaker', "")
-    style_wav = request.args.get('style-wav', "")
+    text = request.args.get("text")
+    speaker_json_key = request.args.get("speaker", "")
+    style_wav = request.args.get("style-wav", "")
     style_wav = style_wav_uri_to_dict(style_wav)
 
     print(" > Model input: {}".format(text))
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index d449a678..f2cfbd43 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -66,17 +66,19 @@ def compute_style_mel(style_wav, ap, cuda=False):
 
 
 def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
     speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
-    if 'tacotron' in CONFIG.model.lower():
+    if "tacotron" in CONFIG.model.lower():
         if not CONFIG.use_gst:
             style_mel = None
         if truncated:
             decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
+                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+            )
         else:
             decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
-    elif 'glow' in CONFIG.model.lower():
+                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+            )
+    elif "glow" in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
@@ -84,9 +86,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
                 inputs, inputs_lengths, g=speaker_embedding_g
             )
         else:
-            postnet_output, _, _, _, alignments, _, _ = model.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
-            )
+            postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
@@ -95,13 +95,9 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
-            postnet_output, alignments = model.module.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
-            )
+            postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_embedding_g)
         else:
-            postnet_output, alignments = model.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
-            )
+            postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py
index ffca1253..4de25300 100644
--- a/TTS/utils/generic_utils.py
+++ b/TTS/utils/generic_utils.py
@@ -162,7 +162,9 @@ def check_argument(
                 is_valid = True
         assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
     elif val_type:
-        assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
+        assert (
+            isinstance(c[name], val_type) or c[name] is None
+        ), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
 
 
 def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
@@ -176,7 +178,7 @@ def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
         Union[str, dict]: path to file (str) or gst style (dict)
     """
     if os.path.isfile(style_wav) and style_wav.endswith(".wav"):
-        return style_wav # style_wav is a .wav file located on the server
+        return style_wav  # style_wav is a .wav file located on the server
 
     style_wav = json.loads(style_wav)
-    return style_wav # style_wav is a gst dictionary with {token1_id : token1_weigth, ...}
+    return style_wav  # style_wav is a gst dictionary with {token1_id : token1_weigth, ...}
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index 84ca6111..95330dca 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -68,14 +68,11 @@ class Synthesizer(object):
     def _get_segmenter(lang: str):
         return pysbd.Segmenter(language=lang, clean=True)
 
-
     def _load_speakers(self, speaker_file: str) -> None:
         print("Loading speakers ...")
         self.tts_speakers = load_speaker_mapping(speaker_file)
         self.num_speakers = len(self.tts_speakers)
-        self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]][
-            "embedding"
-        ])
+        self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"])
 
     def _load_speaker_embedding(self, speaker_json_key: str = ""):
 
@@ -86,14 +83,14 @@ class Synthesizer(object):
 
         if speaker_json_key != "":
             assert self.tts_speakers
-            assert speaker_json_key in self.tts_speakers, f" [!] speaker_json_key is not in self.tts_speakers keys : '{speaker_json_key}'"
+            assert (
+                speaker_json_key in self.tts_speakers
+            ), f" [!] speaker_json_key is not in self.tts_speakers keys : '{speaker_json_key}'"
             speaker_embedding = self.tts_speakers[speaker_json_key]["embedding"]
 
         return speaker_embedding
 
-    def _load_tts(
-        self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool
-    ) -> None:
+    def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
         # pylint: disable=global-statement
         global symbols, phonemes
 
@@ -111,20 +108,19 @@ class Synthesizer(object):
             self.input_size = len(symbols)
 
         if self.tts_config.use_speaker_embedding is True:
-            self._load_speakers(self.tts_config.get('external_speaker_embedding_file', self.tts_speakers_file))
+            self._load_speakers(self.tts_config.get("external_speaker_embedding_file", self.tts_speakers_file))
 
         self.tts_model = setup_model(
             self.input_size,
             num_speakers=self.num_speakers,
             c=self.tts_config,
-            speaker_embedding_dim=self.speaker_embedding_dim)
+            speaker_embedding_dim=self.speaker_embedding_dim,
+        )
 
         self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
         if use_cuda:
             self.tts_model.cuda()
 
-
-
     def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
         self.vocoder_config = load_config(model_config)
         self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"])
@@ -140,7 +136,7 @@ class Synthesizer(object):
         wav = np.array(wav)
         self.ap.save_wav(wav, path, self.output_sample_rate)
 
-    def tts(self, text: str, speaker_json_key: str = "", style_wav = None) -> List[int]:
+    def tts(self, text: str, speaker_json_key: str = "", style_wav=None) -> List[int]:
         start_time = time.time()
         wavs = []
         sens = self._split_into_sentences(text)
@@ -178,13 +174,9 @@ class Synthesizer(object):
             ]
             if scale_factor[1] != 1:
                 print(" > interpolating tts model output.")
-                vocoder_input = interpolate_vocoder_input(
-                    scale_factor, vocoder_input
-                )
+                vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
             else:
-                vocoder_input = torch.tensor(vocoder_input).unsqueeze(
-                    0
-                )  # pylint: disable=not-callable
+                vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
            # run vocoder model
             # [1, T, C]
             waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
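
Usage note: style_wav_uri_to_dict() (reformatted above) accepts two forms, which
the /api/tts handler passes straight through from the "style-wav" query parameter:
a path to a .wav file that exists on the server, or a JSON dict mapping GST token
ids to weights. A minimal client sketch of both forms; the host/port, file path,
and token weights below are assumptions for illustration, not part of this patch:

    import json

    import requests

    SERVER = "http://localhost:5002"  # assumed default address of TTS/server/server.py

    # Form 1: "style-wav" is a server-side .wav path; style_wav_uri_to_dict()
    # returns it unchanged.
    resp = requests.get(
        f"{SERVER}/api/tts",
        params={"text": "Hello world.", "style-wav": "/data/styles/happy.wav"},  # hypothetical path
    )

    # Form 2: "style-wav" is a JSON mapping of GST token ids to weights;
    # style_wav_uri_to_dict() rebuilds the dict with json.loads().
    style = json.dumps({"0": 0.3, "1": -0.1})  # hypothetical token ids / weights
    resp = requests.get(f"{SERVER}/api/tts", params={"text": "Hello world.", "style-wav": style})

    # The endpoint responds with the rendered wav bytes (Flask send_file).
    with open("out.wav", "wb") as f:
        f.write(resp.content)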