mirror of https://github.com/coqui-ai/TTS.git
Draft ONNX export for VITS (#2563)
* Draft ONNX export for VITS. Could not get it to output variable-length sequences yet.

* Fixup for ONNX constant output

* Make style

* Remove commented code
parent 16c9df0dfe
commit 4de797bb11
@@ -1758,6 +1758,115 @@ class Vits(BaseTTS):
         )
         return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
 
+    def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True):
+        """Export the model to ONNX format for inference.
+
+        Args:
+            output_path (str): Path to save the exported model.
+            verbose (bool): Print verbose information. Defaults to True.
+        """
+
+        # save values to roll back after export
+        _forward = self.forward
+        disc = self.disc
+        training = self.training
+
+        # set export mode
+        self.disc = None
+        self.eval()
+
+        def onnx_inference(text, text_lengths, scales, sid=None):
+            # `inference()` reads these attributes, so route the `scales`
+            # input through them to keep the values adjustable at runtime
+            self.inference_noise_scale = scales[0]
+            self.length_scale = scales[1]
+            self.inference_noise_scale_dp = scales[2]
+            return self.inference(
+                text,
+                aux_input={
+                    "x_lengths": text_lengths,
+                    "d_vectors": None,
+                    "speaker_ids": sid,
+                    "language_ids": None,
+                    "durations": None,
+                },
+            )["model_outputs"]
+
+        self.forward = onnx_inference
+
+        # set dummy inputs
+        dummy_input_length = 100
+        sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
+        sequence_lengths = torch.LongTensor([sequences.size(1)])
+        speaker_id = None
+        if self.num_speakers > 1:
+            speaker_id = torch.LongTensor([0])
+        scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
+        dummy_input = (sequences, sequence_lengths, scales, speaker_id)
+
+        # export to ONNX
+        torch.onnx.export(
+            model=self,
+            args=dummy_input,
+            opset_version=15,
+            f=output_path,
+            verbose=verbose,
+            input_names=["input", "input_lengths", "scales", "sid"],
+            output_names=["output"],
+            dynamic_axes={
+                "input": {0: "batch_size", 1: "phonemes"},
+                "input_lengths": {0: "batch_size"},
+                "output": {0: "batch_size", 1: "time1", 2: "time2"},
+            },
+        )
+
+        # rollback
+        self.forward = _forward
+        if training:
+            self.train()
+        self.disc = disc
+
+    def load_onnx(self, model_path: str, cuda=False):
+        import onnxruntime as ort
+
+        providers = ["CPUExecutionProvider" if cuda is False else "CUDAExecutionProvider"]
+        sess_options = ort.SessionOptions()
+        self.onnx_sess = ort.InferenceSession(
+            model_path,
+            sess_options=sess_options,
+            providers=providers,
+        )
+
+    def inference_onnx(self, x, x_lengths=None):
+        """ONNX inference (only single-speaker models are supported).
+
+        TODO: implement multi-speaker support.
+        """
+
+        if isinstance(x, torch.Tensor):
+            x = x.cpu().numpy()
+
+        if x_lengths is None:
+            x_lengths = np.array([x.shape[1]], dtype=np.int64)
+
+        if isinstance(x_lengths, torch.Tensor):
+            x_lengths = x_lengths.cpu().numpy()
+        scales = np.array(
+            [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
+            dtype=np.float32,
+        )
+        audio = self.onnx_sess.run(
+            ["output"],
+            {
+                "input": x,
+                "input_lengths": x_lengths,
+                "scales": scales,
+                # single-speaker graphs expose no `sid` input, so none is fed here
+            },
+        )
+        return audio[0][0]
+
+
 ##################################
 # VITS CHARACTERS
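A minimal usage sketch for the new export entry point (not part of the commit); the config and checkpoint paths are hypothetical, and the model is assumed to be a trained single-speaker VITS:

import torch
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

config = VitsConfig()
config.load_json("run/config.json")  # hypothetical config path
model = Vits.init_from_config(config)
model.load_checkpoint(config, "run/best_model.pth")  # hypothetical checkpoint path
model.export_onnx(output_path="coqui_vits.onnx", verbose=False)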
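Because the dummy speaker id stays None for single-speaker models, only the first three of the four declared input_names should end up bound in the traced graph (my reading of the export call, not something the commit states). A quick way to check what the exported file actually exposes:

import onnxruntime as ort

sess = ort.InferenceSession("coqui_vits.onnx", providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])  # expected: ['input', 'input_lengths', 'scales']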
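And a sketch of the runtime side, assuming token_ids is an int64 array of character/phoneme IDs produced by the model's tokenizer (the IDs below are placeholders):

import numpy as np
import scipy.io.wavfile

token_ids = np.asarray([[12, 5, 27, 3]], dtype=np.int64)  # placeholder IDs, shape [1, T]

model.load_onnx("coqui_vits.onnx", cuda=False)
wav = model.inference_onnx(token_ids)  # shape [1, T_wav], float32

scipy.io.wavfile.write("out.wav", model.config.audio.sample_rate, wav[0])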
@@ -50,11 +50,10 @@ def sequence_mask(sequence_length, max_len=None):
         - mask: :math:`[B, T_max]`
     """
     if max_len is None:
-        max_len = sequence_length.data.max()
+        max_len = sequence_length.max()
     seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
     # B x T_max
-    mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
-    return mask
+    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
 
 
 def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False):
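The sequence_mask rework matters for the export above: dropping .data and returning the comparison directly keeps the function pure tensor arithmetic, which traces cleanly to ONNX. A quick sanity check of the behavior (illustrative only):

import torch
from TTS.tts.utils.helpers import sequence_mask

lengths = torch.tensor([2, 4, 3])
print(sequence_mask(lengths))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True, False]])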
@@ -158,10 +157,8 @@ def generate_path(duration, mask):
         - mask: :math:`[B, T_en, T_de]`
         - path: :math:`[B, T_en, T_de]`
     """
-    device = duration.device
     b, t_x, t_y = mask.shape
     cum_duration = torch.cumsum(duration, 1)
-    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
 
     cum_duration_flat = cum_duration.view(b * t_x)
     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
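Likewise for generate_path: the deleted torch.zeros allocation was dead code (its result was immediately overwritten by the sequence_mask call), and sequence_mask already produces its output on the right device, so the explicit device= plumbing goes too. A small worked example, assuming the rest of the function (the shift-and-subtract that turns cumulative masks into per-token frame spans) is unchanged by this commit:

import torch
from TTS.tts.utils.helpers import generate_path

duration = torch.tensor([[1, 3, 2]])  # 3 input tokens expanding to 6 output frames
mask = torch.ones(1, 3, 6)
print(generate_path(duration, mask)[0])
# tensor([[1., 0., 0., 0., 0., 0.],
#         [0., 1., 1., 1., 0., 0.],
#         [0., 0., 0., 0., 1., 1.]])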