Draft ONNX export for VITS (#2563)

* Draft ONNX export for VITS

Could not get it to work with variable-length output sequences.

* Fixup for onnx constant output

* Make style

* Remove commented code
Eren Gölge 2023-05-16 01:07:56 +02:00 committed by GitHub
parent 16c9df0dfe
commit 4de797bb11
2 changed files with 111 additions and 5 deletions


@@ -1758,6 +1758,115 @@ class Vits(BaseTTS):
             )
         return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
+
+    def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True):
+        """Export model to ONNX format for inference
+
+        Args:
+            output_path (str): Path to save the exported model.
+            verbose (bool): Print verbose information. Defaults to True.
+        """
+        # rollback values
+        _forward = self.forward
+        disc = self.disc
+        training = self.training
+
+        # set export mode
+        self.disc = None
+        self.eval()
+
+        def onnx_inference(text, text_lengths, scales, sid=None):
+            noise_scale = scales[0]
+            length_scale = scales[1]
+            noise_scale_dp = scales[2]
+            self.noise_scale = noise_scale
+            self.length_scale = length_scale
+            self.noise_scale_dp = noise_scale_dp
+            return self.inference(
+                text,
+                aux_input={
+                    "x_lengths": text_lengths,
+                    "d_vectors": None,
+                    "speaker_ids": sid,
+                    "language_ids": None,
+                    "durations": None,
+                },
+            )["model_outputs"]
+
+        self.forward = onnx_inference
+
+        # set dummy inputs
+        dummy_input_length = 100
+        sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
+        sequence_lengths = torch.LongTensor([sequences.size(1)])
+        speaker_id = None
+        if self.num_speakers > 1:
+            speaker_id = torch.LongTensor([0])
+        scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
+        dummy_input = (sequences, sequence_lengths, scales, speaker_id)
+
+        # export to ONNX
+        torch.onnx.export(
+            model=self,
+            args=dummy_input,
+            opset_version=15,
+            f=output_path,
+            verbose=verbose,
+            input_names=["input", "input_lengths", "scales", "sid"],
+            output_names=["output"],
+            dynamic_axes={
+                "input": {0: "batch_size", 1: "phonemes"},
+                "input_lengths": {0: "batch_size"},
+                "output": {0: "batch_size", 1: "time1", 2: "time2"},
+            },
+        )
+
+        # rollback
+        self.forward = _forward
+        if training:
+            self.train()
+        self.disc = disc
+
+    def load_onnx(self, model_path: str, cuda=False):
+        import onnxruntime as ort
+
+        providers = ["CPUExecutionProvider" if cuda is False else "CUDAExecutionProvider"]
+        sess_options = ort.SessionOptions()
+        self.onnx_sess = ort.InferenceSession(
+            model_path,
+            sess_options=sess_options,
+            providers=providers,
+        )
+
+    def inference_onnx(self, x, x_lengths=None):
+        """ONNX inference (only single speaker models are supported)
+
+        TODO: implement multi speaker support.
+        """
+        if isinstance(x, torch.Tensor):
+            x = x.cpu().numpy()
+
+        if x_lengths is None:
+            x_lengths = np.array([x.shape[1]], dtype=np.int64)
+
+        if isinstance(x_lengths, torch.Tensor):
+            x_lengths = x_lengths.cpu().numpy()
+
+        scales = np.array(
+            [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
+            dtype=np.float32,
+        )
+        audio = self.onnx_sess.run(
+            ["output"],
+            {
+                "input": x,
+                "input_lengths": x_lengths,
+                "scales": scales,
+                "sid": None,
+            },
+        )
+        return audio[0][0]
 ##################################
 # VITS CHARACTERS
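
For reference, a minimal usage sketch of the new methods (not part of this diff). It assumes the existing Coqui TTS entry points (VitsConfig, Vits.init_from_config, load_checkpoint, tokenizer.text_to_ids); the checkpoint/config paths and the input sentence are placeholders.

import torch

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

# load a trained model (hypothetical paths)
config = VitsConfig()
config.load_json("config.json")
vits = Vits.init_from_config(config)
vits.load_checkpoint(config, "model.pth")

# export the graph, then reload it through ONNX Runtime
vits.export_onnx(output_path="coqui_vits.onnx")
vits.load_onnx("coqui_vits.onnx", cuda=False)

# single-speaker inference; returns a numpy waveform
text_inputs = torch.LongTensor(vits.tokenizer.text_to_ids("Hello world")).unsqueeze(0)
wav = vits.inference_onnx(text_inputs)

As the commit message notes, variable-length output is not fully working yet, so the exported graph may still produce a constant-length waveform.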


@@ -50,11 +50,10 @@ def sequence_mask(sequence_length, max_len=None):
         - mask: :math:`[B, T_max]`
     """
     if max_len is None:
-        max_len = sequence_length.data.max()
+        max_len = sequence_length.max()
     seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
     # B x T_max
-    mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
-    return mask
+    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
 
 
 def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False):
@@ -158,10 +157,8 @@ def generate_path(duration, mask):
         - mask: :math:`[B, T_en, T_de]`
         - path: :math:`[B, T_en, T_de]`
     """
-    device = duration.device
     b, t_x, t_y = mask.shape
     cum_duration = torch.cumsum(duration, 1)
-    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
 
     cum_duration_flat = cum_duration.view(b * t_x)
     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
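
For context, the helpers.py edits drop the .data access and the pre-allocated torch.zeros path buffer, presumably to keep shapes dynamic when the graph is traced for ONNX. Below is a self-contained sketch of the same cumulative-duration masking trick on toy tensors; the example values and the final pad/diff step (which mirrors the remainder of generate_path) are illustrative, not part of this commit.

import torch
import torch.nn.functional as F

def sequence_mask(sequence_length, max_len=None):
    # same logic as the patched helper: boolean mask of shape [B, T_max]
    if max_len is None:
        max_len = sequence_length.max()
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

# toy durations: one batch item, 3 input tokens lasting 2, 1 and 3 decoder frames
duration = torch.tensor([[2, 1, 3]])                    # [B, T_en]
b, t_x = duration.shape
t_y = int(duration.sum())                               # total decoder frames
cum_duration = torch.cumsum(duration, 1)                # [[2, 3, 6]]
path = sequence_mask(cum_duration.view(b * t_x), t_y)   # [B * T_en, T_de]
path = path.view(b, t_x, t_y).float()
# subtracting the rows shifted down by one token leaves a one-hot alignment per row
path = path - F.pad(path, (0, 0, 1, 0))[:, :-1]
print(path)
# tensor([[[1., 1., 0., 0., 0., 0.],
#          [0., 0., 1., 0., 0., 0.],
#          [0., 0., 0., 1., 1., 1.]]])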