mirror of https://github.com/coqui-ai/TTS.git
Style update
parent
a89eb12aca
commit
d6e29ef98a
|
@ -70,7 +70,9 @@ class FFTransformerBlock(nn.Module):
|
|||
|
||||
|
||||
class FFTDurationPredictor:
|
||||
def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument
|
||||
def __init__(
|
||||
self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None
|
||||
): # pylint: disable=unused-argument
|
||||
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
|
||||
self.proj = nn.Linear(in_channels, 1)
|
||||
|
||||
|
|
|
@ -52,5 +52,3 @@ def prepare_stop_target(inputs, out_steps):
|
|||
|
||||
def pad_per_step(inputs, pad_len):
|
||||
return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0)
|
||||
|
||||
|
||||
|
|
|
@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase):
|
|||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data['text']
|
||||
text_lengths = data['text_lengths']
|
||||
speaker_name = data['speaker_names']
|
||||
linear_input = data['linear']
|
||||
mel_input = data['mel']
|
||||
mel_lengths = data['mel_lengths']
|
||||
stop_target = data['stop_targets']
|
||||
item_idx = data['item_idxs']
|
||||
wavs = data['waveform']
|
||||
text_input = data["text"]
|
||||
text_lengths = data["text_lengths"]
|
||||
speaker_name = data["speaker_names"]
|
||||
linear_input = data["linear"]
|
||||
mel_input = data["mel"]
|
||||
mel_lengths = data["mel_lengths"]
|
||||
stop_target = data["stop_targets"]
|
||||
item_idx = data["item_idxs"]
|
||||
wavs = data["waveform"]
|
||||
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
|
@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase):
|
|||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data['text']
|
||||
text_lengths = data['text_lengths']
|
||||
speaker_name = data['speaker_names']
|
||||
linear_input = data['linear']
|
||||
mel_input = data['mel']
|
||||
mel_lengths = data['mel_lengths']
|
||||
stop_target = data['stop_targets']
|
||||
item_idx = data['item_idxs']
|
||||
text_input = data["text"]
|
||||
text_lengths = data["text_lengths"]
|
||||
speaker_name = data["speaker_names"]
|
||||
linear_input = data["linear"]
|
||||
mel_input = data["mel"]
|
||||
mel_lengths = data["mel_lengths"]
|
||||
stop_target = data["stop_targets"]
|
||||
item_idx = data["item_idxs"]
|
||||
|
||||
avg_length = mel_lengths.numpy().mean()
|
||||
assert avg_length >= last_length
|
||||
|
@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase):
|
|||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data['text']
|
||||
text_lengths = data['text_lengths']
|
||||
speaker_name = data['speaker_names']
|
||||
linear_input = data['linear']
|
||||
mel_input = data['mel']
|
||||
mel_lengths = data['mel_lengths']
|
||||
stop_target = data['stop_targets']
|
||||
item_idx = data['item_idxs']
|
||||
text_input = data["text"]
|
||||
text_lengths = data["text_lengths"]
|
||||
speaker_name = data["speaker_names"]
|
||||
linear_input = data["linear"]
|
||||
mel_input = data["mel"]
|
||||
mel_lengths = data["mel_lengths"]
|
||||
stop_target = data["stop_targets"]
|
||||
item_idx = data["item_idxs"]
|
||||
|
||||
# check mel_spec consistency
|
||||
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
|
||||
|
@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase):
|
|||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data['text']
|
||||
text_lengths = data['text_lengths']
|
||||
speaker_name = data['speaker_names']
|
||||
linear_input = data['linear']
|
||||
mel_input = data['mel']
|
||||
mel_lengths = data['mel_lengths']
|
||||
stop_target = data['stop_targets']
|
||||
item_idx = data['item_idxs']
|
||||
text_input = data["text"]
|
||||
text_lengths = data["text_lengths"]
|
||||
speaker_name = data["speaker_names"]
|
||||
linear_input = data["linear"]
|
||||
mel_input = data["mel"]
|
||||
mel_lengths = data["mel_lengths"]
|
||||
stop_target = data["stop_targets"]
|
||||
item_idx = data["item_idxs"]
|
||||
|
||||
if mel_lengths[0] > mel_lengths[1]:
|
||||
idx = 0
|
||||
|
|
Loading…
Reference in New Issue