linter fix

pull/10/head
erogol 2020-07-11 23:32:43 +02:00
parent 6448c87a55
commit da3478bef1
1 changed files with 4 additions and 5 deletions

View File

@ -8,7 +8,6 @@ tf.get_logger().setLevel('INFO')
from TTS.utils.io import load_config from TTS.utils.io import load_config
from TTS.tf.models.tacotron2 import Tacotron2 from TTS.tf.models.tacotron2 import Tacotron2
from TTS.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model from TTS.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model
from TTS.utils.synthesis import run_model_tflite, text_to_seqvec
#pylint: disable=unused-variable #pylint: disable=unused-variable
@ -16,7 +15,7 @@ torch.manual_seed(1)
use_cuda = torch.cuda.is_available() use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
file_path = os.path.dirname(os.path.realpath(__file__)).replace('/tf/','/') file_path = os.path.dirname(os.path.realpath(__file__)).replace('/tf/', '/')
c = load_config(os.path.join(file_path, 'test_config.json')) c = load_config(os.path.join(file_path, 'test_config.json'))
@ -90,7 +89,7 @@ class TacotronTFTrainTest(unittest.TestCase):
# inference pass # inference pass
output = model(chars_seq, training=False) output = model(chars_seq, training=False)
def test_tflite_conversion(self, ): def test_tflite_conversion(self, ): #pylint:disable=no-self-use
model = Tacotron2(num_chars=24, model = Tacotron2(num_chars=24,
num_speakers=0, num_speakers=0,
r=3, r=3,
@ -114,13 +113,13 @@ class TacotronTFTrainTest(unittest.TestCase):
# init tflite model # init tflite model
tflite_model = load_tflite_model('test_tacotron2.tflite') tflite_model = load_tflite_model('test_tacotron2.tflite')
# fake input # fake input
inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32) inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32) #pylint:disable=unexpected-keyword-arg
# run inference # run inference
# get input and output details # get input and output details
input_details = tflite_model.get_input_details() input_details = tflite_model.get_input_details()
output_details = tflite_model.get_output_details() output_details = tflite_model.get_output_details()
# reshape input tensor for the new input shape # reshape input tensor for the new input shape
tflite_model.resize_tensor_input(input_details[0]['index'], inputs.shape) tflite_model.resize_tensor_input(input_details[0]['index'], inputs.shape) #pylint:disable=unexpected-keyword-arg
tflite_model.allocate_tensors() tflite_model.allocate_tensors()
detail = input_details[0] detail = input_details[0]
input_shape = detail['shape'] input_shape = detail['shape']