Make style

pull/2870/head
Eren G??lge 2023-08-11 12:55:23 +02:00
parent 9a8352b8da
commit 37b558ccb9
4 changed files with 16 additions and 21 deletions

View File

@ -191,10 +191,10 @@ lock = Lock()
@app.route("/api/tts", methods=["GET", "POST"])
def tts():
with lock:
text = request.headers.get('text') or request.values.get("text", "")
speaker_idx = request.headers.get('speaker-id') or request.values.get("speaker_id", "")
language_idx = request.headers.get('language-id') or request.values.get("language_id", "")
style_wav = request.headers.get('style-wav') or request.values.get("style_wav", "")
text = request.headers.get("text") or request.values.get("text", "")
speaker_idx = request.headers.get("speaker-id") or request.values.get("speaker_id", "")
language_idx = request.headers.get("language-id") or request.values.get("language_id", "")
style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "")
style_wav = style_wav_uri_to_dict(style_wav)
print(f" > Model input: {text}")

View File

@ -102,7 +102,7 @@ class HubertTokenizer(nn.Module):
model_zip.close()
@staticmethod
def load_from_checkpoint(path, map_location = None):
def load_from_checkpoint(path, map_location=None):
old = True
with ZipFile(path) as model_zip:
filesMatch = [file for file in model_zip.namelist() if file.endswith("/.info")]

View File

@ -184,7 +184,6 @@ def generate_text_semantic(
Returns:
np.ndarray: The generated semantic tokens.
"""
print(f"history_prompt in gen: {history_prompt}")
assert isinstance(text, str)
text = _normalize_whitespace(text)
assert len(text.strip()) > 0

View File

@ -1849,17 +1849,17 @@ class Vits(BaseTTS):
scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
dummy_input = (sequences, sequence_lengths, scales)
input_names = ["input", "input_lengths", "scales"]
if self.num_speakers > 0:
speaker_id = torch.LongTensor([0])
dummy_input += (speaker_id, )
dummy_input += (speaker_id,)
input_names.append("sid")
if hasattr(self, 'num_languages') and self.num_languages > 0 and self.embedded_language_dim > 0:
if hasattr(self, "num_languages") and self.num_languages > 0 and self.embedded_language_dim > 0:
language_id = torch.LongTensor([0])
dummy_input += (language_id, )
input_names.append("langid")
dummy_input += (language_id,)
input_names.append("langid")
# export to ONNX
torch.onnx.export(
model=self,
@ -1875,7 +1875,7 @@ class Vits(BaseTTS):
"output": {0: "batch_size", 1: "time1", 2: "time2"},
},
)
# rollback
self.forward = _forward
if training:
@ -1885,7 +1885,7 @@ class Vits(BaseTTS):
def load_onnx(self, model_path: str, cuda=False):
import onnxruntime as ort
providers = [
"CPUExecutionProvider"
if cuda is False
@ -1913,16 +1913,12 @@ class Vits(BaseTTS):
[self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
dtype=np.float32,
)
input_params = {
"input": x,
"input_lengths": x_lengths,
"scales": scales
}
input_params = {"input": x, "input_lengths": x_lengths, "scales": scales}
if not speaker_id is None:
input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy()
if not language_id is None:
input_params["langid"] = torch.tensor([language_id]).cpu().numpy()
audio = self.onnx_sess.run(
["output"],
input_params,