mirror of https://github.com/coqui-ai/TTS.git
linter fix
parent
6448c87a55
commit
da3478bef1
|
@ -8,7 +8,6 @@ tf.get_logger().setLevel('INFO')
|
||||||
from TTS.utils.io import load_config
|
from TTS.utils.io import load_config
|
||||||
from TTS.tf.models.tacotron2 import Tacotron2
|
from TTS.tf.models.tacotron2 import Tacotron2
|
||||||
from TTS.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model
|
from TTS.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model
|
||||||
from TTS.utils.synthesis import run_model_tflite, text_to_seqvec
|
|
||||||
|
|
||||||
#pylint: disable=unused-variable
|
#pylint: disable=unused-variable
|
||||||
|
|
||||||
|
@ -16,7 +15,7 @@ torch.manual_seed(1)
|
||||||
use_cuda = torch.cuda.is_available()
|
use_cuda = torch.cuda.is_available()
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
file_path = os.path.dirname(os.path.realpath(__file__)).replace('/tf/','/')
|
file_path = os.path.dirname(os.path.realpath(__file__)).replace('/tf/', '/')
|
||||||
c = load_config(os.path.join(file_path, 'test_config.json'))
|
c = load_config(os.path.join(file_path, 'test_config.json'))
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,7 +89,7 @@ class TacotronTFTrainTest(unittest.TestCase):
|
||||||
# inference pass
|
# inference pass
|
||||||
output = model(chars_seq, training=False)
|
output = model(chars_seq, training=False)
|
||||||
|
|
||||||
def test_tflite_conversion(self, ):
|
def test_tflite_conversion(self, ): #pylint:disable=no-self-use
|
||||||
model = Tacotron2(num_chars=24,
|
model = Tacotron2(num_chars=24,
|
||||||
num_speakers=0,
|
num_speakers=0,
|
||||||
r=3,
|
r=3,
|
||||||
|
@ -114,13 +113,13 @@ class TacotronTFTrainTest(unittest.TestCase):
|
||||||
# init tflite model
|
# init tflite model
|
||||||
tflite_model = load_tflite_model('test_tacotron2.tflite')
|
tflite_model = load_tflite_model('test_tacotron2.tflite')
|
||||||
# fake input
|
# fake input
|
||||||
inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32)
|
inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32) #pylint:disable=unexpected-keyword-arg
|
||||||
# run inference
|
# run inference
|
||||||
# get input and output details
|
# get input and output details
|
||||||
input_details = tflite_model.get_input_details()
|
input_details = tflite_model.get_input_details()
|
||||||
output_details = tflite_model.get_output_details()
|
output_details = tflite_model.get_output_details()
|
||||||
# reshape input tensor for the new input shape
|
# reshape input tensor for the new input shape
|
||||||
tflite_model.resize_tensor_input(input_details[0]['index'], inputs.shape)
|
tflite_model.resize_tensor_input(input_details[0]['index'], inputs.shape) #pylint:disable=unexpected-keyword-arg
|
||||||
tflite_model.allocate_tensors()
|
tflite_model.allocate_tensors()
|
||||||
detail = input_details[0]
|
detail = input_details[0]
|
||||||
input_shape = detail['shape']
|
input_shape = detail['shape']
|
||||||
|
|
Loading…
Reference in New Issue