Style update

pull/800/head
Eren Gölge 2021-09-10 08:30:33 +00:00
parent a89eb12aca
commit d6e29ef98a
3 changed files with 36 additions and 36 deletions

View File

@ -70,7 +70,9 @@ class FFTransformerBlock(nn.Module):
class FFTDurationPredictor:
def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument
def __init__(
self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None
): # pylint: disable=unused-argument
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
self.proj = nn.Linear(in_channels, 1)

View File

@ -52,5 +52,3 @@ def prepare_stop_target(inputs, out_steps):
def pad_per_step(inputs, pad_len):
    """Append ``pad_len`` zeros to the end of the third axis of ``inputs``.

    The pad specification lists exactly three axes; the first two are left
    unpadded and only the trailing side of the third one is extended.

    Args:
        inputs: array-like handed to ``np.pad`` (the three-entry pad spec
            implies a 3-dimensional input).
        pad_len: number of zero entries appended after the existing data
            on the third axis.

    Returns:
        The padded array produced by ``np.pad``.
    """
    no_padding = (0, 0)
    pad_spec = (no_padding, no_padding, (0, pad_len))
    return np.pad(inputs, pad_spec, mode="constant", constant_values=0.0)

View File

@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
wavs = data['waveform']
text_input = data["text"]
text_lengths = data["text_lengths"]
speaker_name = data["speaker_names"]
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
wavs = data["waveform"]
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
text_input = data["text"]
text_lengths = data["text_lengths"]
speaker_name = data["speaker_names"]
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
avg_length = mel_lengths.numpy().mean()
assert avg_length >= last_length
@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
text_input = data["text"]
text_lengths = data["text_lengths"]
speaker_name = data["speaker_names"]
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
# check mel_spec consistency
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase):
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data['text']
text_lengths = data['text_lengths']
speaker_name = data['speaker_names']
linear_input = data['linear']
mel_input = data['mel']
mel_lengths = data['mel_lengths']
stop_target = data['stop_targets']
item_idx = data['item_idxs']
text_input = data["text"]
text_lengths = data["text_lengths"]
speaker_name = data["speaker_names"]
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
if mel_lengths[0] > mel_lengths[1]:
idx = 0