diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py index 2fe9bcc4..9b7ecee2 100644 --- a/TTS/tts/layers/generic/transformer.py +++ b/TTS/tts/layers/generic/transformer.py @@ -70,7 +70,9 @@ class FFTransformerBlock(nn.Module): class FFTDurationPredictor: - def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument + def __init__( + self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None + ): # pylint: disable=unused-argument self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p) self.proj = nn.Linear(in_channels, 1) diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py index 7a766958..d91a828e 100644 --- a/TTS/tts/utils/data.py +++ b/TTS/tts/utils/data.py @@ -52,5 +52,3 @@ def prepare_stop_target(inputs, out_steps): def pad_per_step(inputs, pad_len): return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0) - - diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 0fbb6bde..18066ef3 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data['text'] - text_lengths = data['text_lengths'] - speaker_name = data['speaker_names'] - linear_input = data['linear'] - mel_input = data['mel'] - mel_lengths = data['mel_lengths'] - stop_target = data['stop_targets'] - item_idx = data['item_idxs'] - wavs = data['waveform'] + text_input = data["text"] + text_lengths = data["text_lengths"] + speaker_name = data["speaker_names"] + linear_input = data["linear"] + mel_input = data["mel"] + mel_lengths = data["mel_lengths"] + stop_target = data["stop_targets"] + item_idx = data["item_idxs"] + wavs = data["waveform"] neg_values = text_input[text_input < 0] check_count = len(neg_values) @@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data['text'] - text_lengths = data['text_lengths'] - speaker_name = data['speaker_names'] - linear_input = data['linear'] - mel_input = data['mel'] - mel_lengths = data['mel_lengths'] - stop_target = data['stop_targets'] - item_idx = data['item_idxs'] + text_input = data["text"] + text_lengths = data["text_lengths"] + speaker_name = data["speaker_names"] + linear_input = data["linear"] + mel_input = data["mel"] + mel_lengths = data["mel_lengths"] + stop_target = data["stop_targets"] + item_idx = data["item_idxs"] avg_length = mel_lengths.numpy().mean() assert avg_length >= last_length @@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data['text'] - text_lengths = data['text_lengths'] - speaker_name = data['speaker_names'] - linear_input = data['linear'] - mel_input = data['mel'] - mel_lengths = data['mel_lengths'] - stop_target = data['stop_targets'] - item_idx = data['item_idxs'] + text_input = data["text"] + text_lengths = data["text_lengths"] + speaker_name = data["speaker_names"] + linear_input = data["linear"] + mel_input = data["mel"] + mel_lengths = data["mel_lengths"] + stop_target = data["stop_targets"] + item_idx = data["item_idxs"] # check mel_spec consistency wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32) @@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase): for i, data in enumerate(dataloader): if i == self.max_loader_iter: break - text_input = data['text'] - text_lengths = data['text_lengths'] - speaker_name = data['speaker_names'] - linear_input = data['linear'] - mel_input = data['mel'] - mel_lengths = data['mel_lengths'] - stop_target = data['stop_targets'] - item_idx = data['item_idxs'] + text_input = data["text"] + text_lengths = data["text_lengths"] + speaker_name = data["speaker_names"] + linear_input = data["linear"] + mel_input = data["mel"] + mel_lengths = data["mel_lengths"] + stop_target = data["stop_targets"] + item_idx = data["item_idxs"] if mel_lengths[0] > mel_lengths[1]: idx = 0