import os
import unittest

import numpy as np
import tensorflow as tf
import torch

from TTS.tts.configs import Tacotron2Config
from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model

tf.get_logger().setLevel("INFO")

# pylint: disable=unused-variable

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

c = Tacotron2Config()


class TacotronTFTrainTest(unittest.TestCase):
    """Smoke tests for the TensorFlow Tacotron2 model and its TFLite export."""

    @staticmethod
    def generate_dummy_inputs():
        """Build one random batch of 8 samples for the forward-pass tests.

        Returns:
            A 7-tuple ``(chars_seq, chars_seq_lengths, mel_spec,
            mel_postnet_spec, mel_lengths, stop_targets, speaker_ids)``.
            The first three items are converted to TensorFlow tensors; the
            remaining four are left as torch tensors on ``device``.
        """
        chars_seq = torch.randint(0, 24, (8, 128)).long().to(device)
        chars_seq_lengths = torch.randint(100, 128, (8,)).long().to(device)
        # Lengths are sorted in descending order.
        chars_seq_lengths = torch.sort(chars_seq_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)

        # Only the tensors that are fed to the TF model are converted.
        chars_seq = tf.convert_to_tensor(chars_seq.cpu().numpy())
        chars_seq_lengths = tf.convert_to_tensor(chars_seq_lengths.cpu().numpy())
        mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy())
        return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths, stop_targets, speaker_ids

    def _check_forward_passes(self, **model_kwargs):
        """Run a training and an inference pass and check the output shapes.

        Shared body of ``test_train_step`` and ``test_forward_attention``.

        Args:
            **model_kwargs: Extra keyword arguments forwarded to the
                ``Tacotron2`` constructor on top of ``num_chars=24``,
                ``r=c.r`` and ``num_speakers=5``.
        """
        (
            chars_seq,
            chars_seq_lengths,
            mel_spec,
            _,  # mel_postnet_spec: not used by these tests
            mel_lengths,
            stop_targets,
            _,  # speaker_ids: not used by these tests
        ) = self.generate_dummy_inputs()

        # NOTE: the stop targets are not passed to the model below. The
        # computation is kept as in the original tests so that any error it
        # would raise (e.g. from the reshape by ``c.r``) still surfaces.
        for idx in mel_lengths:
            stop_targets[:, int(idx.item()) :, 0] = 1.0
        stop_targets = stop_targets.view(chars_seq.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, **model_kwargs)

        # training pass
        output = model(chars_seq, chars_seq_lengths, mel_spec, training=True)

        # check model output shapes
        reduced_frames = mel_spec.shape[1] // model.decoder.r
        assert np.all(output[0].shape == mel_spec.shape)
        assert np.all(output[1].shape == mel_spec.shape)
        assert output[2].shape[2] == chars_seq.shape[1]
        assert output[2].shape[1] == reduced_frames
        assert output[3].shape[1] == reduced_frames

        # inference pass
        output = model(chars_seq, training=False)

    def test_train_step(self):
        """test forward pass"""
        self._check_forward_passes()

    def test_forward_attention(self):
        """test forward pass with ``forward_attn=True``"""
        self._check_forward_passes(forward_attn=True)

    def test_tflite_conversion(self):  # pylint:disable=no-self-use
        """Convert a Tacotron2 model to TFLite and run one inference on it.

        The exported binary is removed even when a step fails, so a broken
        run does not leave ``test_tacotron2.tflite`` in the working directory.
        """
        tflite_path = "test_tacotron2.tflite"
        model = Tacotron2(
            num_chars=24,
            num_speakers=0,
            r=3,
            out_channels=80,
            decoder_output_dim=80,
            attn_type="original",
            attn_win=False,
            attn_norm="sigmoid",
            prenet_type="original",
            prenet_dropout=True,
            forward_attn=False,
            trans_agent=False,
            forward_attn_mask=False,
            location_attn=True,
            attn_K=0,
            separate_stopnet=True,
            bidirectional_decoder=False,
            enable_tflite=True,
        )
        model.build_inference()
        try:
            convert_tacotron2_to_tflite(model, output_path=tflite_path, experimental_converter=True)

            # init tflite model
            tflite_model = load_tflite_model(tflite_path)

            # fake input
            inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)  # pylint:disable=unexpected-keyword-arg

            # run inference
            # get input and output details
            input_details = tflite_model.get_input_details()
            output_details = tflite_model.get_output_details()

            # reshape input tensor for the new input shape
            tflite_model.resize_tensor_input(
                input_details[0]["index"], inputs.shape
            )  # pylint:disable=unexpected-keyword-arg
            tflite_model.allocate_tensors()
            detail = input_details[0]
            tflite_model.set_tensor(detail["index"], inputs)

            # run the tflite_model
            tflite_model.invoke()

            # collect outputs
            decoder_output = tflite_model.get_tensor(output_details[0]["index"])
            postnet_output = tflite_model.get_tensor(output_details[1]["index"])
        finally:
            # remove tflite binary; guard on existence so a failed conversion
            # is not masked by a FileNotFoundError raised from the cleanup
            if os.path.exists(tflite_path):
                os.remove(tflite_path)