import os
import numpy as np
import collections
import torch
import random
from torch.utils.data import Dataset

from TTS.utils.text import text_to_sequence, phoneme_to_sequence, pad_with_eos_bos
from TTS.utils.data import prepare_data, prepare_tensor, prepare_stop_target


class MyDataset(Dataset):
    def __init__(self,
                 ap,
                 meta_data,
                 voice_len=1.6,
                 num_speakers_in_batch=64,
                 num_utter_per_speaker=10,
                 skip_speakers=False,
                 verbose=False):
        """
        Args:
            ap (TTS.utils.AudioProcessor): audio processor object.
            meta_data (list): list of dataset instances.
            voice_len (float): voice segment length in seconds.
            num_speakers_in_batch (int): number of speakers in each batch.
            num_utter_per_speaker (int): number of utterances sampled per speaker.
            skip_speakers (bool): skip speakers with fewer than num_utter_per_speaker utterances.
            verbose (bool): print diagnostic information.
        """
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.voice_len = voice_len
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()
        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
        audio = self.ap.load_wav(filename)
        return audio

    def load_data(self, idx):
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        mel = self.ap.melspectrogram(wav).astype('float32')
        # sample seq_len

        assert text.size > 0, self.items[idx][1]
        assert wav.size > 0, self.items[idx][1]

        sample = {
            'mel': mel,
            'item_idx': self.items[idx][1],
            'speaker_name': speaker_name
        }
        return sample

    def __parse_items(self):
        """
        Find unique speaker ids and create a dict mapping each speaker id to its utterances.
        """
        speakers = list(set([item[-1] for item in self.items]))
        self.speaker_to_utters = {}
        self.speakers = []
        for speaker in speakers:
            speaker_utters = [item[1] for item in self.items if item[2] == speaker]
            if len(speaker_utters) < self.num_utter_per_speaker and self.skip_speakers:
                print(
                    f" [!] Skipped speaker {speaker}. Not enough utterances "
                    f"{self.num_utter_per_speaker} vs {len(speaker_utters)}.")
            else:
                self.speakers.append(speaker)
                self.speaker_to_utters[speaker] = speaker_utters

    def __len__(self):
        # Behave as an effectively infinite stream: speakers are drawn at random
        # in __getitem__, so the reported length only needs to be large enough
        # for the DataLoader to keep sampling.
        return int(1e+10)

    def __sample_speaker(self):
        speaker = random.sample(self.speakers, 1)[0]
        if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
            utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
        else:
            utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker)
        return speaker, utters

    def __sample_speaker_utterances(self, speaker):
        """
        Sample all M utterances for the given speaker.
        """
        feats = []
        labels = []
        for idx in range(self.num_utter_per_speaker):
            # TODO: dummy but works
            while True:
                if len(self.speaker_to_utters[speaker]) > 0:
                    utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
                else:
                    self.speakers.remove(speaker)
                    speaker, _ = self.__sample_speaker()
                    continue
                wav = self.load_wav(utter)
                if wav.shape[0] - self.seq_len > 0:
                    break
                else:
                    # utterance not longer than seq_len: drop it and retry
                    self.speaker_to_utters[speaker].remove(utter)

            offset = random.randint(0, wav.shape[0] - self.seq_len)
            mel = self.ap.melspectrogram(wav[offset:offset + self.seq_len])
            feats.append(torch.FloatTensor(mel))
            labels.append(speaker)
        return feats, labels

    def __getitem__(self, idx):
        speaker, _ = self.__sample_speaker()
        return speaker

    def collate_fn(self, batch):
        labels = []
        feats = []
        for speaker in batch:
            feats_, labels_ = self.__sample_speaker_utterances(speaker)
            labels.append(labels_)
            feats.extend(feats_)
        feats = torch.stack(feats)
        return feats.transpose(1, 2), labels
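

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# how this dataset is typically wired into a PyTorch DataLoader for speaker
# encoder training. The `AudioProcessor` arguments, the `meta_data` contents
# and the batch size below are assumptions for illustration, not the project's
# canonical training setup.
#
#   from torch.utils.data import DataLoader
#   from TTS.utils.audio import AudioProcessor
#
#   ap = AudioProcessor(sample_rate=16000)            # hypothetical config
#   meta_data = [("", "/path/to/utt.wav", "spk1")]    # (text, wav_path, speaker)
#   dataset = MyDataset(ap, meta_data, voice_len=1.6,
#                       num_speakers_in_batch=64, num_utter_per_speaker=10)
#
#   # Each __getitem__ returns a speaker id; collate_fn then samples
#   # num_utter_per_speaker mel segments per speaker, so a batch has shape
#   # [batch_size * num_utter_per_speaker, T, num_mels] after the transpose.
#   loader = DataLoader(dataset, batch_size=64, collate_fn=dataset.collate_fn)
#   feats, labels = next(iter(loader))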