mirror of https://github.com/coqui-ai/TTS.git
tflite utils and tflite backend for utils/synthesis.py
parent d41cb7fe47
commit de3c5c95d0
@@ -0,0 +1,44 @@
# Convert a TensorFlow Tacotron2 model to a TF-Lite binary.
import argparse

import tensorflow as tf

from TTS.utils.io import load_config
from TTS.utils.text.symbols import symbols, phonemes
from TTS.tf.utils.generic_utils import setup_model
from TTS.tf.utils.io import load_checkpoint
from TTS.tf.utils.tflite import convert_tacotron2_to_tflite


parser = argparse.ArgumentParser()
parser.add_argument('--tf_model',
                    type=str,
                    help='Path to the TF model checkpoint to be converted to TF-Lite.')
parser.add_argument('--config_path',
                    type=str,
                    help='Path to the config file of the model.')
parser.add_argument('--output_path',
                    type=str,
                    help='Path to the output TF-Lite binary.')
args = parser.parse_args()

# set constants
CONFIG = load_config(args.config_path)

# load the model
c = CONFIG
num_speakers = 0
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c, enable_tflite=True)
model.build_inference()
model = load_checkpoint(model, args.tf_model)
model.decoder.set_max_decoder_steps(1000)

# create the TF-Lite model
tflite_model = convert_tacotron2_to_tflite(model)

# save the TF-Lite binary
with open(args.output_path, 'wb') as f:
    f.write(tflite_model)

print(f'TF-Lite model size is {len(tflite_model) / (1024.0 * 1024.0):.2f} MBs.')
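The written binary can be sanity-checked by loading it back into a TF-Lite interpreter. A minimal sketch, assuming the script above has produced 'tacotron2.tflite' (an illustrative path, not from this commit):

import tensorflow as tf

# load the converted binary and inspect its input/output signature
interpreter = tf.lite.Interpreter(model_path='tacotron2.tflite')  # illustrative path
interpreter.allocate_tensors()
print(interpreter.get_input_details())
print(interpreter.get_output_details())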
@@ -0,0 +1,42 @@
import pickle
import datetime

import tensorflow as tf


def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
    # collect the training state and pickle it to disk
    state = {
        'model': model.weights,
        'optimizer': optimizer,
        'step': current_step,
        'epoch': epoch,
        'date': datetime.date.today().strftime("%B %d, %Y"),
        'r': r
    }
    state.update(kwargs)
    with open(output_path, 'wb') as f:
        pickle.dump(state, f)


def load_checkpoint(model, checkpoint_path):
    with open(checkpoint_path, 'rb') as f:
        checkpoint = pickle.load(f)
    chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
    tf_vars = model.weights
    for tf_var in tf_vars:
        layer_name = tf_var.name
        try:
            chkp_var_value = chkp_var_dict[layer_name]
        except KeyError:
            # if the plain name misses, retry with the checkpoint's class-name prefix
            class_name = list(chkp_var_dict.keys())[0].split("/")[0]
            layer_name = f"{class_name}/{layer_name}"
            chkp_var_value = chkp_var_dict[layer_name]
        tf.keras.backend.set_value(tf_var, chkp_var_value)
    if 'r' in checkpoint:
        model.decoder.set_r(checkpoint['r'])
    return model


def load_tflite_model(tflite_path):
    tflite_model = tf.lite.Interpreter(model_path=tflite_path)
    tflite_model.allocate_tensors()
    return tflite_model
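save_checkpoint pickles the model's variables and load_checkpoint restores them by name via tf.keras.backend.set_value. A minimal, self-contained sketch of the same name-keyed round trip, using plain numpy values for portability (the Dense layer is illustrative, not part of TTS):

import pickle
import tensorflow as tf

layer = tf.keras.layers.Dense(4)
layer.build((None, 8))

# save: map variable names to numpy values, the lookup key the loader relies on
blob = pickle.dumps({var.name: var.numpy() for var in layer.weights})

# restore: look each variable up by name and copy its value back
restored = pickle.loads(blob)
for var in layer.weights:
    tf.keras.backend.set_value(var, restored[var.name])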
@@ -0,0 +1,20 @@
import tensorflow as tf


def convert_tacotron2_to_tflite(model):
    # trace the model's TF-Lite friendly inference function
    tacotron2_concrete_function = model.inference_tflite.get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [tacotron2_concrete_function])
    converter.experimental_new_converter = True
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # fall back to regular TF ops where no TF-Lite builtin exists
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                           tf.lite.OpsSet.SELECT_TF_OPS]
    tflite_model = converter.convert()
    return tflite_model


def load_tflite_model(tflite_path):
    tflite_model = tf.lite.Interpreter(model_path=tflite_path)
    tflite_model.allocate_tensors()
    return tflite_model
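SELECT_TF_OPS is what lets Tacotron2's dynamic decoding loop convert at all: ops with no TF-Lite builtin are kept as regular TF kernels, at the cost of a larger runtime. The same converter path can be exercised in isolation; a minimal sketch with a toy traced function (the toy function is an assumption standing in for model.inference_tflite):

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([1, None], tf.float32)])
def toy(x):
    # stand-in for the model's tflite inference function
    return tf.reduce_sum(x, axis=1)

converter = tf.lite.TFLiteConverter.from_concrete_functions([toy.get_concrete_function()])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
tflite_bytes = converter.convert()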
@@ -70,6 +70,31 @@ def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
    return decoder_output, postnet_output, alignments, stop_tokens


def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
    if CONFIG.use_gst and style_mel is not None:
        raise NotImplementedError(' [!] GST inference not implemented for TfLite')
    if truncated:
        raise NotImplementedError(' [!] Truncated inference not implemented for TfLite')
    if speaker_id is not None:
        raise NotImplementedError(' [!] Multi-Speaker not implemented for TfLite')
    # get input and output details
    input_details = model.get_input_details()
    output_details = model.get_output_details()
    # reshape the input tensor for the new input shape
    model.resize_tensor_input(input_details[0]['index'], inputs.shape)
    model.allocate_tensors()
    detail = input_details[0]
    model.set_tensor(detail['index'], inputs)
    # run the model
    model.invoke()
    # collect outputs
    decoder_output = model.get_tensor(output_details[0]['index'])
    postnet_output = model.get_tensor(output_details[1]['index'])
    # tflite model only returns feature frames
    return decoder_output, postnet_output, None, None


def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
    postnet_output = postnet_output[0].data.cpu().numpy()
    decoder_output = decoder_output[0].data.cpu().numpy()
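A hedged usage sketch for run_model_tflite, feeding dummy token ids through a converted model ('config.json' and 'tacotron2.tflite' are illustrative paths, not from this commit):

import numpy as np
from TTS.utils.io import load_config
from TTS.tf.utils.tflite import load_tflite_model
from TTS.utils.synthesis import run_model_tflite

CONFIG = load_config('config.json')            # illustrative path
model = load_tflite_model('tacotron2.tflite')  # illustrative path
# dummy batch of 50 token ids; the interpreter is resized to this shape
dummy_ids = np.random.randint(0, 40, size=(1, 50)).astype(np.int32)
decoder_output, postnet_output, _, _ = run_model_tflite(
    model, dummy_ids, CONFIG, truncated=False)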
@@ -86,6 +111,12 @@ def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
    return postnet_output, decoder_output, alignment, stop_tokens


def parse_outputs_tflite(postnet_output, decoder_output):
    # strip the batch dimension from each output
    postnet_output = postnet_output[0]
    decoder_output = decoder_output[0]
    return postnet_output, decoder_output


def trim_silence(wav, ap):
    return wav[:ap.find_endpoint(wav)]
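As a quick check of the shape contract, a minimal sketch with dummy batched outputs (the mel dimension of 80 is an assumption from typical configs):

import numpy as np
from TTS.utils.synthesis import parse_outputs_tflite

postnet = np.zeros((1, 100, 80), dtype=np.float32)  # [batch, frames, mels]
decoder = np.zeros((1, 100, 80), dtype=np.float32)
postnet_out, decoder_out = parse_outputs_tflite(postnet, decoder)
assert postnet_out.shape == (100, 80)  # batch dimension removed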
@@ -164,22 +195,31 @@ def synthesis(model,
        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
        inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
        inputs = inputs.unsqueeze(0)
    elif backend == 'tf':
        # TODO: handle speaker id for tf model
        style_mel = numpy_to_tf(style_mel, tf.float32)
        inputs = numpy_to_tf(inputs, tf.int32)
        inputs = tf.expand_dims(inputs, 0)
    elif backend == 'tflite':
        style_mel = numpy_to_tf(style_mel, tf.float32)
        inputs = numpy_to_tf(inputs, tf.int32)
        inputs = tf.expand_dims(inputs, 0)
    # synthesize voice
    if backend == 'torch':
        decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
            model, inputs, CONFIG, truncated, speaker_id, style_mel)
        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
            postnet_output, decoder_output, alignments, stop_tokens)
    elif backend == 'tf':
        decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
            model, inputs, CONFIG, truncated, speaker_id, style_mel)
        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
            postnet_output, decoder_output, alignments, stop_tokens)
    elif backend == 'tflite':
        decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
            model, inputs, CONFIG, truncated, speaker_id, style_mel)
        postnet_output, decoder_output = parse_outputs_tflite(
            postnet_output, decoder_output)
    # convert outputs to numpy
    # plot results
    wav = None
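With these branches in place, the TF-Lite backend is selected per call. A hedged end-to-end sketch; the ap audio processor, the text argument's position, and any keyword names beyond those visible in this hunk are assumptions about the surrounding synthesis() signature:

from TTS.utils.synthesis import synthesis

# model: a tf.lite.Interpreter from load_tflite_model (see above)
outputs = synthesis(model, "Hello world.", CONFIG, use_cuda=False,
                    ap=ap, backend='tflite')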