commenting TTSDataset.py

pull/10/head
Eren Golge 2019-11-19 12:39:31 +01:00
parent 79cca4ac80
commit ee788bc558
1 changed files with 4 additions and 1 deletions

View File

@ -176,6 +176,8 @@ class MyDataset(Dataset):
if isinstance(batch[0], collections.Mapping):
text_lenghts = np.array([len(d["text"]) for d in batch])
# sort items with text input length for RNN efficiency
text_lenghts, ids_sorted_decreasing = torch.sort(
torch.LongTensor(text_lenghts), dim=0, descending=True)
@ -187,6 +189,7 @@ class MyDataset(Dataset):
speaker_name = [batch[idx]['speaker_name']
for idx in ids_sorted_decreasing]
# compute features
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
@ -211,7 +214,7 @@ class MyDataset(Dataset):
assert mel.shape[2] == linear.shape[2]
timesteps = mel.shape[2]
# B x T x D
# B x D x T --> B x T x D
linear = linear.transpose(0, 2, 1)
mel = mel.transpose(0, 2, 1)