From de3c5c95d08f9d2b19e369891a0922002b7d8d95 Mon Sep 17 00:00:00 2001
From: erogol
Date: Thu, 9 Jul 2020 11:45:09 +0200
Subject: [PATCH] tflite utils and tflite backend for utils/synthesis.py

---
 tf/convert_tacotron2_tflite.py | 44 ++++++++++++++++++++++++++++++++++
 tf/utils/io.py                 | 42 ++++++++++++++++++++++++++++++++
 tf/utils/tflite.py             | 20 ++++++++++++++++
 utils/synthesis.py             | 44 ++++++++++++++++++++++++++++++++--
 4 files changed, 148 insertions(+), 2 deletions(-)
 create mode 100644 tf/convert_tacotron2_tflite.py
 create mode 100644 tf/utils/io.py
 create mode 100644 tf/utils/tflite.py

diff --git a/tf/convert_tacotron2_tflite.py b/tf/convert_tacotron2_tflite.py
new file mode 100644
index 00000000..28039376
--- /dev/null
+++ b/tf/convert_tacotron2_tflite.py
@@ -0,0 +1,44 @@
+# Convert a TensorFlow Tacotron2 model to a TF-Lite binary
+
+import tensorflow as tf
+import argparse
+
+from TTS.utils.io import load_config
+from TTS.utils.text.symbols import symbols, phonemes, make_symbols
+from TTS.tf.utils.generic_utils import setup_model
+from TTS.tf.utils.io import load_checkpoint
+from TTS.tf.utils.tflite import convert_tacotron2_to_tflite
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--tf_model',
+                    type=str,
+                    help='Path to the TF model checkpoint to be converted to TF-Lite.')
+parser.add_argument('--config_path',
+                    type=str,
+                    help='Path to the model config file.')
+parser.add_argument('--output_path',
+                    type=str,
+                    help='Path to the output TF-Lite binary.')
+args = parser.parse_args()
+
+# Set constants
+CONFIG = load_config(args.config_path)
+
+# load the model
+c = CONFIG
+num_speakers = 0
+num_chars = len(phonemes) if c.use_phonemes else len(symbols)
+model = setup_model(num_chars, num_speakers, c, enable_tflite=True)
+model.build_inference()
+model = load_checkpoint(model, args.tf_model)
+model.decoder.set_max_decoder_steps(1000)
+
+# create tflite model
+tflite_model = convert_tacotron2_to_tflite(model)
+
+# save tflite binary
+with open(args.output_path, 'wb') as f:
+    f.write(tflite_model)
+
+print(f'TF-Lite model size is {len(tflite_model) / (1024.0 * 1024.0):.2f} MB.')
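
For reference, invoking the script above looks like the following; the checkpoint
and config paths are illustrative placeholders, not files shipped with this patch:

    python tf/convert_tacotron2_tflite.py \
        --tf_model checkpoints/tts_tf_checkpoint.pkl \
        --config_path checkpoints/config.json \
        --output_path tts_model.tflite
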
diff --git a/tf/utils/io.py b/tf/utils/io.py
new file mode 100644
index 00000000..78a56de4
--- /dev/null
+++ b/tf/utils/io.py
@@ -0,0 +1,42 @@
+import pickle
+import datetime
+import tensorflow as tf
+
+
+def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
+    state = {
+        'model': model.weights,
+        'optimizer': optimizer,
+        'step': current_step,
+        'epoch': epoch,
+        'date': datetime.date.today().strftime("%B %d, %Y"),
+        'r': r
+    }
+    state.update(kwargs)
+    pickle.dump(state, open(output_path, 'wb'))
+
+
+def load_checkpoint(model, checkpoint_path):
+    checkpoint = pickle.load(open(checkpoint_path, 'rb'))
+    chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
+    tf_vars = model.weights
+    for tf_var in tf_vars:
+        layer_name = tf_var.name
+        try:
+            chkp_var_value = chkp_var_dict[layer_name]
+        except KeyError:
+            class_name = list(chkp_var_dict.keys())[0].split("/")[0]
+            layer_name = f"{class_name}/{layer_name}"
+            chkp_var_value = chkp_var_dict[layer_name]
+
+        tf.keras.backend.set_value(tf_var, chkp_var_value)
+    if 'r' in checkpoint.keys():
+        model.decoder.set_r(checkpoint['r'])
+    return model
+
+
+def load_tflite_model(tflite_path):
+    tflite_model = tf.lite.Interpreter(model_path=tflite_path)
+    tflite_model.allocate_tensors()
+    return tflite_model
+
diff --git a/tf/utils/tflite.py b/tf/utils/tflite.py
new file mode 100644
index 00000000..a46c1dce
--- /dev/null
+++ b/tf/utils/tflite.py
@@ -0,0 +1,20 @@
+import tensorflow as tf
+
+
+def convert_tacotron2_to_tflite(model):
+    tacotron2_concrete_function = model.inference_tflite.get_concrete_function()
+    converter = tf.lite.TFLiteConverter.from_concrete_functions(
+        [tacotron2_concrete_function]
+    )
+    converter.experimental_new_converter = True
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
+                                           tf.lite.OpsSet.SELECT_TF_OPS]
+    tflite_model = converter.convert()
+    return tflite_model
+
+
+def load_tflite_model(tflite_path):
+    tflite_model = tf.lite.Interpreter(model_path=tflite_path)
+    tflite_model.allocate_tensors()
+    return tflite_model
\ No newline at end of file
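
Once converted, the binary can be sanity-checked with the TF-Lite Interpreter
alone, without any TTS code. A minimal sketch, assuming a dummy batch of
character IDs (the int32 dtype and the resize/allocate sequence mirror
run_model_tflite below; the model path and vocabulary size are illustrative):

    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path='tts_model.tflite')
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # dummy input: one sequence of 50 character IDs
    inputs = np.random.randint(0, 40, size=(1, 50), dtype=np.int32)
    interpreter.resize_tensor_input(input_details[0]['index'], inputs.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]['index'], inputs)
    interpreter.invoke()

    # the model exposes decoder and postnet outputs as its two output tensors
    decoder_output = interpreter.get_tensor(output_details[0]['index'])
    postnet_output = interpreter.get_tensor(output_details[1]['index'])
    print(decoder_output.shape, postnet_output.shape)
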
diff --git a/utils/synthesis.py b/utils/synthesis.py
index 03d7072e..056a7b46 100644
--- a/utils/synthesis.py
+++ b/utils/synthesis.py
@@ -70,6 +70,31 @@ def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=No
     return decoder_output, postnet_output, alignments, stop_tokens
 
 
+def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
+    if CONFIG.use_gst and style_mel is not None:
+        raise NotImplementedError(' [!] GST inference not implemented for TfLite')
+    if truncated:
+        raise NotImplementedError(' [!] Truncated inference not implemented for TfLite')
+    if speaker_id is not None:
+        raise NotImplementedError(' [!] Multi-Speaker not implemented for TfLite')
+    # get input and output details
+    input_details = model.get_input_details()
+    output_details = model.get_output_details()
+    # reshape input tensor for the new input shape
+    model.resize_tensor_input(input_details[0]['index'], inputs.shape)
+    model.allocate_tensors()
+    detail = input_details[0]
+    input_shape = detail['shape']
+    model.set_tensor(detail['index'], inputs)
+    # run the model
+    model.invoke()
+    # collect outputs
+    decoder_output = model.get_tensor(output_details[0]['index'])
+    postnet_output = model.get_tensor(output_details[1]['index'])
+    # tflite model only returns feature frames
+    return decoder_output, postnet_output, None, None
+
+
 def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
     postnet_output = postnet_output[0].data.cpu().numpy()
     decoder_output = decoder_output[0].data.cpu().numpy()
@@ -86,6 +111,12 @@ def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
     return postnet_output, decoder_output, alignment, stop_tokens
 
 
+def parse_outputs_tflite(postnet_output, decoder_output):
+    postnet_output = postnet_output[0]
+    decoder_output = decoder_output[0]
+    return postnet_output, decoder_output
+
+
 def trim_silence(wav, ap):
     return wav[:ap.find_endpoint(wav)]
 
@@ -164,22 +195,31 @@ def synthesis(model,
         style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
         inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
         inputs = inputs.unsqueeze(0)
-    else:
+    elif backend == 'tf':
         # TODO: handle speaker id for tf model
         style_mel = numpy_to_tf(style_mel, tf.float32)
        
         inputs = numpy_to_tf(inputs, tf.int32)
         inputs = tf.expand_dims(inputs, 0)
+    elif backend == 'tflite':
+        style_mel = numpy_to_tf(style_mel, tf.float32)
+        inputs = numpy_to_tf(inputs, tf.int32)
+        inputs = tf.expand_dims(inputs, 0)
     # synthesize voice
     if backend == 'torch':
         decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
             model, inputs, CONFIG, truncated, speaker_id, style_mel)
         postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
             postnet_output, decoder_output, alignments, stop_tokens)
-    else:
+    elif backend == 'tf':
         decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
             model, inputs, CONFIG, truncated, speaker_id, style_mel)
         postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
             postnet_output, decoder_output, alignments, stop_tokens)
+    elif backend == 'tflite':
+        decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
+            model, inputs, CONFIG, truncated, speaker_id, style_mel)
+        postnet_output, decoder_output = parse_outputs_tflite(
+            postnet_output, decoder_output)
     # convert outputs to numpy
     # plot results
     wav = None
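
With the synthesis.py changes in place, a TF-Lite model slots into the same
run/parse flow as the torch and tf backends. A minimal sketch using only this
patch's helpers; the module paths, config file, and binary path are assumptions
for illustration:

    import numpy as np
    from TTS.utils.io import load_config
    from TTS.tf.utils.io import load_tflite_model
    from TTS.utils.synthesis import run_model_tflite, parse_outputs_tflite

    CONFIG = load_config('config.json')            # assumed config path
    model = load_tflite_model('tts_model.tflite')  # assumed binary path

    # dummy character-ID batch; real inputs come from the text processing in synthesis()
    inputs = np.random.randint(0, 40, size=(1, 50), dtype=np.int32)
    decoder_output, postnet_output, _, _ = run_model_tflite(
        model, inputs, CONFIG, truncated=False)
    postnet_output, decoder_output = parse_outputs_tflite(
        postnet_output, decoder_output)
    # postnet_output now holds mel frames; a vocoder (or Griffin-Lim) renders audio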