code styling

pull/441/head
Eren Gölge 2021-04-16 16:01:40 +02:00
parent af7baa3387
commit aadb2106ec
4 changed files with 32 additions and 44 deletions

TTS/server/server.py

@@ -7,10 +7,10 @@ from pathlib import Path
 from flask import Flask, render_template, request, send_file
+from TTS.utils.generic_utils import style_wav_uri_to_dict
 from TTS.utils.io import load_config
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
-from TTS.utils.generic_utils import style_wav_uri_to_dict
 def create_argparser():
@@ -90,11 +90,9 @@ app = Flask(__name__)
 @app.route("/")
 def index():
     return render_template(
-        'index.html',
-        show_details=args.show_details,
-        use_speaker_embedding=use_speaker_embedding,
-        use_gst = use_gst
-    )
+        "index.html", show_details=args.show_details, use_speaker_embedding=use_speaker_embedding, use_gst=use_gst
+    )
 @app.route("/details")
 def details():
@@ -115,9 +113,9 @@ def details():
 @app.route("/api/tts", methods=["GET"])
 def tts():
-    text = request.args.get('text')
-    speaker_json_key = request.args.get('speaker', "")
-    style_wav = request.args.get('style-wav', "")
+    text = request.args.get("text")
+    speaker_json_key = request.args.get("speaker", "")
+    style_wav = request.args.get("style-wav", "")
     style_wav = style_wav_uri_to_dict(style_wav)
     print(" > Model input: {}".format(text))

TTS/tts/utils/synthesis.py

@@ -66,17 +66,19 @@ def compute_style_mel(style_wav, ap, cuda=False):
 def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
     speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
-    if 'tacotron' in CONFIG.model.lower():
+    if "tacotron" in CONFIG.model.lower():
         if not CONFIG.use_gst:
             style_mel = None
         if truncated:
             decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
+                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+            )
         else:
             decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
-    elif 'glow' in CONFIG.model.lower():
+                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+            )
+    elif "glow" in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
@@ -84,9 +86,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
                 inputs, inputs_lengths, g=speaker_embedding_g
             )
         else:
-            postnet_output, _, _, _, alignments, _, _ = model.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
-            )
+            postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
@@ -95,13 +95,9 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
-            postnet_output, alignments = model.module.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
-            )
+            postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_embedding_g)
         else:
-            postnet_output, alignments = model.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
-            )
+            postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
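Note: run_model_torch dispatches on a substring of the configured model name; only the tacotron branch uses GST style_mel and truncated inference, while the glow branch builds explicit input lengths. A sketch of the dispatch with hypothetical CONFIG.model values (the condition of the third branch is cut off in the hunk above):

    for model_name in ("Tacotron2", "glow_tts", "other_model"):
        if "tacotron" in model_name.lower():
            branch = "tacotron: optional style_mel, truncated inference supported"
        elif "glow" in model_name.lower():
            branch = "glow: explicit inputs_lengths, optional g=speaker embedding"
        else:
            branch = "third branch (condition not visible in this hunk)"
        print(f"{model_name} -> {branch}")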

TTS/utils/generic_utils.py

@@ -162,7 +162,9 @@ def check_argument(
                 is_valid = True
         assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
     elif val_type:
-        assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
+        assert (
+            isinstance(c[name], val_type) or c[name] is None
+        ), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
 def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
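Note: the parenthesized assert produced by the formatter is behaviorally identical to the one-liner it replaced; a runnable illustration with a hypothetical config entry:

    c = {"num_mels": 80}
    name, val_type = "num_mels", int
    assert (
        isinstance(c[name], val_type) or c[name] is None
    ), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"  # passes: 80 is an int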
@@ -176,7 +178,7 @@ def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
         Union[str, dict]: path to file (str) or gst style (dict)
     """
     if os.path.isfile(style_wav) and style_wav.endswith(".wav"):
-        return style_wav # style_wav is a .wav file located on the server
+        return style_wav  # style_wav is a .wav file located on the server
     style_wav = json.loads(style_wav)
-    return style_wav # style_wav is a gst dictionary with {token1_id : token1_weigth, ...}
+    return style_wav  # style_wav is a gst dictionary with {token1_id : token1_weigth, ...}
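Note: a usage sketch for style_wav_uri_to_dict (the path and token weights are illustrative):

    from TTS.utils.generic_utils import style_wav_uri_to_dict

    # A .wav file that exists on the server is passed through unchanged:
    style_wav_uri_to_dict("/data/styles/happy.wav")  # -> "/data/styles/happy.wav"

    # Anything else is parsed as JSON into a GST dict {token_id: token_weight}:
    style_wav_uri_to_dict('{"0": 0.3, "1": -0.1}')  # -> {"0": 0.3, "1": -0.1}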

TTS/utils/synthesizer.py

@@ -68,14 +68,11 @@ class Synthesizer(object):
     def _get_segmenter(lang: str):
         return pysbd.Segmenter(language=lang, clean=True)
     def _load_speakers(self, speaker_file: str) -> None:
         print("Loading speakers ...")
         self.tts_speakers = load_speaker_mapping(speaker_file)
         self.num_speakers = len(self.tts_speakers)
-        self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]][
-            "embedding"
-        ])
+        self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"])
     def _load_speaker_embedding(self, speaker_json_key: str = ""):
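Note: _load_speakers infers the embedding size from the first entry of the mapping, so the speakers file is assumed to look roughly like this (keys and values illustrative):

    # Assumed shape of the mapping returned by load_speaker_mapping().
    tts_speakers = {
        "p225_001": {"embedding": [0.01, -0.02, 0.03]},
        "p226_004": {"embedding": [0.04, 0.00, -0.01]},
    }
    # Mirrors the reformatted one-liner above:
    speaker_embedding_dim = len(tts_speakers[list(tts_speakers.keys())[0]]["embedding"])  # -> 3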
@@ -86,14 +83,14 @@ class Synthesizer(object):
         if speaker_json_key != "":
             assert self.tts_speakers
-            assert speaker_json_key in self.tts_speakers, f" [!] speaker_json_key is not in self.tts_speakers keys : '{speaker_json_key}'"
+            assert (
+                speaker_json_key in self.tts_speakers
+            ), f" [!] speaker_json_key is not in self.tts_speakers keys : '{speaker_json_key}'"
             speaker_embedding = self.tts_speakers[speaker_json_key]["embedding"]
         return speaker_embedding
-    def _load_tts(
-        self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool
-    ) -> None:
+    def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
         # pylint: disable=global-statement
         global symbols, phonemes
@@ -111,20 +108,19 @@ class Synthesizer(object):
         self.input_size = len(symbols)
         if self.tts_config.use_speaker_embedding is True:
-            self._load_speakers(self.tts_config.get('external_speaker_embedding_file', self.tts_speakers_file))
+            self._load_speakers(self.tts_config.get("external_speaker_embedding_file", self.tts_speakers_file))
         self.tts_model = setup_model(
             self.input_size,
             num_speakers=self.num_speakers,
             c=self.tts_config,
-            speaker_embedding_dim=self.speaker_embedding_dim)
+            speaker_embedding_dim=self.speaker_embedding_dim,
+        )
         self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
         if use_cuda:
             self.tts_model.cuda()
     def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
         self.vocoder_config = load_config(model_config)
         self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"])
@@ -140,7 +136,7 @@ class Synthesizer(object):
         wav = np.array(wav)
         self.ap.save_wav(wav, path, self.output_sample_rate)
-    def tts(self, text: str, speaker_json_key: str = "", style_wav = None) -> List[int]:
+    def tts(self, text: str, speaker_json_key: str = "", style_wav=None) -> List[int]:
         start_time = time.time()
         wavs = []
         sens = self._split_into_sentences(text)
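Note: the reformatted tts() signature keeps the same keyword defaults, so calls are unchanged. A sketch assuming an already-constructed Synthesizer instance named synth and a key that exists in its speaker mapping:

    # speaker_json_key and style_wav are the same optional inputs the
    # /api/tts route forwards; the key below is illustrative.
    wav = synth.tts("Hello world.", speaker_json_key="p225_001", style_wav=None)
    synth.save_wav(wav, "out.wav")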
@@ -178,13 +174,9 @@ class Synthesizer(object):
             ]
             if scale_factor[1] != 1:
                 print(" > interpolating tts model output.")
-                vocoder_input = interpolate_vocoder_input(
-                    scale_factor, vocoder_input
-                )
+                vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
             else:
-                vocoder_input = torch.tensor(vocoder_input).unsqueeze(
-                    0
-                )  # pylint: disable=not-callable
+                vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
             # run vocoder model
             # [1, T, C]
             waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
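Note: the interpolation branch fires when the vocoder and the TTS model disagree on sample rate. A sketch of how scale_factor is presumably derived just above this hunk (rates are illustrative):

    tts_sample_rate = 22050      # e.g. self.ap.sample_rate
    vocoder_sample_rate = 24000  # e.g. self.vocoder_config["audio"]["sample_rate"]
    scale_factor = [1, vocoder_sample_rate / tts_sample_rate]
    if scale_factor[1] != 1:
        print(" > interpolating tts model output.")  # matches the branch above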