TTS/speaker_encoder/dataset.py

import os
import numpy as np
import collections
import torch
import random
from torch.utils.data import Dataset
from TTS.utils.text import text_to_sequence, phoneme_to_sequence, pad_with_eos_bos
from TTS.utils.data import prepare_data, prepare_tensor, prepare_stop_target


class MyDataset(Dataset):
    def __init__(self,
                 ap,
                 meta_data,
                 voice_len=1.6,
                 num_speakers_in_batch=64,
                 num_utter_per_speaker=10,
                 skip_speakers=False,
                 verbose=False):
        """
        Args:
            ap (TTS.utils.AudioProcessor): audio processor object.
            meta_data (list): list of dataset instances.
            voice_len (float): voice segment length in seconds.
            num_speakers_in_batch (int): number of speakers sampled per batch.
            num_utter_per_speaker (int): number of utterances sampled per speaker.
            skip_speakers (bool): skip speakers with fewer than num_utter_per_speaker utterances.
            verbose (bool): print diagnostic information.
        """
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.voice_len = voice_len
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()
        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
        audio = self.ap.load_wav(filename)
        return audio

    def load_data(self, idx):
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        mel = self.ap.melspectrogram(wav).astype('float32')
        # guard against empty transcripts or audio files
        assert len(text) > 0, self.items[idx][1]
        assert wav.size > 0, self.items[idx][1]
        sample = {
            'mel': mel,
            'item_idx': self.items[idx][1],
            'speaker_name': speaker_name
        }
        return sample

    def __parse_items(self):
        """
        Find unique speaker ids and map each speaker id to its utterances.
        """
        speakers = list({item[-1] for item in self.items})
        self.speaker_to_utters = {}
        self.speakers = []
        for speaker in speakers:
            speaker_utters = [item[1] for item in self.items if item[2] == speaker]
            if len(speaker_utters) < self.num_utter_per_speaker and self.skip_speakers:
                print(f" [!] Skipped speaker {speaker}. Not enough utterances: "
                      f"{len(speaker_utters)} < {self.num_utter_per_speaker}.")
            else:
                self.speakers.append(speaker)
                self.speaker_to_utters[speaker] = speaker_utters

    def __len__(self):
        # items are sampled at random, so report a virtually infinite length
        return int(1e10)

    def __sample_speaker(self):
        speaker = random.sample(self.speakers, 1)[0]
        if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
            # not enough utterances left for this speaker, so sample with replacement
            utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
        else:
            utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker)
        return speaker, utters
def __sample_speaker_utterances(self, speaker):
"""
Sample all M utterances for the given speaker.
"""
feats = []
labels = []
for idx in range(self.num_utter_per_speaker):
# TODO:dummy but works
while True:
if len(self.speaker_to_utters[speaker]) > 0:
utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
else:
self.speakers.remove(speaker)
speaker, _ = self.__sample_speaker()
continue
wav = self.load_wav(utter)
if wav.shape[0] - self.seq_len > 0:
break
else:
self.speaker_to_utters[speaker].remove(utter)
offset = random.randint(0, wav.shape[0] - self.seq_len)
mel = self.ap.melspectrogram(wav[offset:offset+self.seq_len])
feats.append(torch.FloatTensor(mel))
labels.append(speaker)
return feats, labels

    def __getitem__(self, idx):
        speaker, _ = self.__sample_speaker()
        return speaker

    def collate_fn(self, batch):
        # `batch` is a list of speaker ids; the actual utterances are sampled here
        labels = []
        feats = []
        for speaker in batch:
            feats_, labels_ = self.__sample_speaker_utterances(speaker)
            labels.append(labels_)
            feats.extend(feats_)
        feats = torch.stack(feats)
        return feats.transpose(1, 2), labels
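

# A minimal usage sketch (an assumption, not part of the original module): `ap`
# stands for a configured TTS audio processor and `meta_data` for a list of
# (text, wav_path, speaker_name) tuples produced by a dataset preprocessor.
# With the DataLoader batch_size set to num_speakers_in_batch, each batch is a
# feature tensor of shape
# [num_speakers_in_batch * num_utter_per_speaker, T, num_mels] plus the
# per-utterance speaker labels, as expected by a GE2E-style speaker encoder loss.
#
# from torch.utils.data import DataLoader
#
# dataset = MyDataset(ap,
#                     meta_data,
#                     voice_len=1.6,
#                     num_speakers_in_batch=64,
#                     num_utter_per_speaker=10,
#                     skip_speakers=True,
#                     verbose=True)
# loader = DataLoader(dataset,
#                     batch_size=64,  # must match num_speakers_in_batch
#                     shuffle=False,
#                     num_workers=4,
#                     collate_fn=dataset.collate_fn)
# feats, labels = next(iter(loader))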