mirror of https://github.com/coqui-ai/TTS.git
Add `PitchExtractor` and return dict by `collate`
parent
debf772ec5
commit
648655fa03
|
@ -130,6 +130,8 @@ class TTSDataset(Dataset):
|
|||
|
||||
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
||||
os.makedirs(phoneme_cache_path, exist_ok=True)
|
||||
if compute_f0:
|
||||
self.pitch_extractor = PitchExtractor(self.items, verbose=verbose)
|
||||
if self.verbose:
|
||||
print("\n > DataLoader initialization")
|
||||
print(" | > Use phonemes: {}".format(self.use_phonemes))
|
||||
|
@ -247,8 +249,8 @@ class TTSDataset(Dataset):
|
|||
|
||||
pitch = None
|
||||
if self.compute_f0:
|
||||
pitch = self._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
|
||||
pitch = self.normalize_pitch(pitch)
|
||||
pitch = self.pitch_extractor._load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
|
||||
pitch = self.pitch_extractor.normalize_pitch(pitch)
|
||||
|
||||
sample = {
|
||||
"raw_text": raw_text,
|
||||
|
@ -317,96 +319,6 @@ class TTSDataset(Dataset):
|
|||
for idx, p in enumerate(phonemes):
|
||||
self.items[idx][0] = p
|
||||
|
||||
################
|
||||
# Pitch Methods
|
||||
###############
|
||||
# TODO: Refactor Pitch methods into a separate class
|
||||
|
||||
@staticmethod
|
||||
def create_pitch_file_path(wav_file, cache_path):
|
||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
|
||||
return pitch_file
|
||||
|
||||
@staticmethod
|
||||
def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
|
||||
wav = ap.load_wav(wav_file)
|
||||
pitch = ap.compute_f0(wav)
|
||||
if pitch_file:
|
||||
np.save(pitch_file, pitch)
|
||||
return pitch
|
||||
|
||||
@staticmethod
|
||||
def compute_pitch_stats(pitch_vecs):
|
||||
nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
|
||||
mean, std = np.mean(nonzeros), np.std(nonzeros)
|
||||
return mean, std
|
||||
|
||||
def normalize_pitch(self, pitch):
|
||||
zero_idxs = np.where(pitch == 0.0)[0]
|
||||
pitch -= self.mean
|
||||
pitch /= self.std
|
||||
pitch[zero_idxs] = 0.0
|
||||
return pitch
|
||||
|
||||
@staticmethod
|
||||
def _load_or_compute_pitch(ap, wav_file, cache_path):
|
||||
"""
|
||||
compute pitch and return a numpy array of pitch values
|
||||
"""
|
||||
pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
|
||||
if not os.path.exists(pitch_file):
|
||||
pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
|
||||
else:
|
||||
pitch = np.load(pitch_file)
|
||||
return pitch
|
||||
|
||||
@staticmethod
|
||||
def _pitch_worker(args):
|
||||
item = args[0]
|
||||
ap = args[1]
|
||||
cache_path = args[2]
|
||||
_, wav_file, *_ = item
|
||||
pitch_file = TTSDataset.create_pitch_file_path(wav_file, cache_path)
|
||||
if not os.path.exists(pitch_file):
|
||||
pitch = TTSDataset._compute_and_save_pitch(ap, wav_file, pitch_file)
|
||||
return pitch
|
||||
return None
|
||||
|
||||
def compute_pitch(self, cache_path, num_workers=0):
|
||||
"""Compute the input sequences with multi-processing.
|
||||
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
|
||||
if not os.path.exists(cache_path):
|
||||
os.makedirs(cache_path, exist_ok=True)
|
||||
|
||||
if self.verbose:
|
||||
print(" | > Computing pitch features ...")
|
||||
if num_workers == 0:
|
||||
pitch_vecs = []
|
||||
for _, item in enumerate(tqdm.tqdm(self.items)):
|
||||
pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])]
|
||||
else:
|
||||
with Pool(num_workers) as p:
|
||||
pitch_vecs = list(
|
||||
tqdm.tqdm(
|
||||
p.imap(TTSDataset._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
|
||||
total=len(self.items),
|
||||
)
|
||||
)
|
||||
pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
|
||||
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
|
||||
np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
|
||||
|
||||
def load_pitch_stats(self, cache_path):
|
||||
stats_path = os.path.join(cache_path, "pitch_stats.npy")
|
||||
stats = np.load(stats_path, allow_pickle=True).item()
|
||||
self.mean = stats["mean"]
|
||||
self.std = stats["std"]
|
||||
|
||||
###################
|
||||
# End Pitch Methods
|
||||
###################
|
||||
|
||||
def sort_and_filter_items(self, by_audio_len=False):
|
||||
r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
|
||||
range.
|
||||
|
@ -588,22 +500,22 @@ class TTSDataset(Dataset):
|
|||
else:
|
||||
attns = None
|
||||
# TODO: return dictionary
|
||||
return (
|
||||
text,
|
||||
text_lenghts,
|
||||
speaker_names,
|
||||
linear,
|
||||
mel,
|
||||
mel_lengths,
|
||||
stop_targets,
|
||||
item_idxs,
|
||||
d_vectors,
|
||||
speaker_ids,
|
||||
attns,
|
||||
wav_padded,
|
||||
raw_text,
|
||||
pitch,
|
||||
)
|
||||
return {
|
||||
"text": text,
|
||||
"text_lengths": text_lenghts,
|
||||
"speaker_names": speaker_names,
|
||||
"linear": linear,
|
||||
"mel": mel,
|
||||
"mel_lengths": mel_lengths,
|
||||
"stop_targets": stop_targets,
|
||||
"item_idxs": item_idxs,
|
||||
"d_vectors": d_vectors,
|
||||
"speaker_ids": speaker_ids,
|
||||
"attns": attns,
|
||||
"waveform": wav_padded,
|
||||
"raw_text": raw_text,
|
||||
"pitch": pitch,
|
||||
}
|
||||
|
||||
raise TypeError(
|
||||
(
|
||||
|
@ -613,3 +525,103 @@ class TTSDataset(Dataset):
|
|||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class PitchExtractor:
|
||||
"""Pitch Extractor for computing F0 from wav files.
|
||||
|
||||
Args:
|
||||
items (List[List]): Dataset samples.
|
||||
verbose (bool): Whether to print the progress.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
items: List[List],
|
||||
verbose=False,
|
||||
):
|
||||
self.items = items
|
||||
self.verbose = verbose
|
||||
self.mean = None
|
||||
self.std = None
|
||||
|
||||
@staticmethod
|
||||
def create_pitch_file_path(wav_file, cache_path):
|
||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
|
||||
return pitch_file
|
||||
|
||||
@staticmethod
|
||||
def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
|
||||
wav = ap.load_wav(wav_file)
|
||||
pitch = ap.compute_f0(wav)
|
||||
if pitch_file:
|
||||
np.save(pitch_file, pitch)
|
||||
return pitch
|
||||
|
||||
@staticmethod
|
||||
def compute_pitch_stats(pitch_vecs):
|
||||
nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
|
||||
mean, std = np.mean(nonzeros), np.std(nonzeros)
|
||||
return mean, std
|
||||
|
||||
def normalize_pitch(self, pitch):
|
||||
zero_idxs = np.where(pitch == 0.0)[0]
|
||||
pitch -= self.mean
|
||||
pitch /= self.std
|
||||
pitch[zero_idxs] = 0.0
|
||||
return pitch
|
||||
|
||||
@staticmethod
|
||||
def _load_or_compute_pitch(ap, wav_file, cache_path):
|
||||
"""
|
||||
compute pitch and return a numpy array of pitch values
|
||||
"""
|
||||
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
|
||||
if not os.path.exists(pitch_file):
|
||||
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
|
||||
else:
|
||||
pitch = np.load(pitch_file)
|
||||
return pitch
|
||||
|
||||
@staticmethod
|
||||
def _pitch_worker(args):
|
||||
item = args[0]
|
||||
ap = args[1]
|
||||
cache_path = args[2]
|
||||
_, wav_file, *_ = item
|
||||
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
|
||||
if not os.path.exists(pitch_file):
|
||||
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
|
||||
return pitch
|
||||
return None
|
||||
|
||||
def compute_pitch(self, cache_path, num_workers=0):
|
||||
"""Compute the input sequences with multi-processing.
|
||||
Call it before passing dataset to the data loader to cache the input sequences for faster data loading."""
|
||||
if not os.path.exists(cache_path):
|
||||
os.makedirs(cache_path, exist_ok=True)
|
||||
|
||||
if self.verbose:
|
||||
print(" | > Computing pitch features ...")
|
||||
if num_workers == 0:
|
||||
pitch_vecs = []
|
||||
for _, item in enumerate(tqdm.tqdm(self.items)):
|
||||
pitch_vecs += [self._pitch_worker([item, self.ap, cache_path])]
|
||||
else:
|
||||
with Pool(num_workers) as p:
|
||||
pitch_vecs = list(
|
||||
tqdm.tqdm(
|
||||
p.imap(PitchExtractor._pitch_worker, [[item, self.ap, cache_path] for item in self.items]),
|
||||
total=len(self.items),
|
||||
)
|
||||
)
|
||||
pitch_mean, pitch_std = self.compute_pitch_stats(pitch_vecs)
|
||||
pitch_stats = {"mean": pitch_mean, "std": pitch_std}
|
||||
np.save(os.path.join(cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)
|
||||
|
||||
def load_pitch_stats(self, cache_path):
|
||||
stats_path = os.path.join(cache_path, "pitch_stats.npy")
|
||||
stats = np.load(stats_path, allow_pickle=True).item()
|
||||
self.mean = stats["mean"]
|
||||
self.std = stats["std"]
|
||||
|
|
|
@ -104,19 +104,19 @@ class BaseTTS(BaseModel):
|
|||
Dict: [description]
|
||||
"""
|
||||
# setup input batch
|
||||
text_input = batch[0]
|
||||
text_lengths = batch[1]
|
||||
speaker_names = batch[2]
|
||||
linear_input = batch[3]
|
||||
mel_input = batch[4]
|
||||
mel_lengths = batch[5]
|
||||
stop_targets = batch[6]
|
||||
item_idx = batch[7]
|
||||
d_vectors = batch[8]
|
||||
speaker_ids = batch[9]
|
||||
attn_mask = batch[10]
|
||||
waveform = batch[11]
|
||||
pitch = batch[13]
|
||||
text_input = batch["text"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
speaker_names = batch["speaker_names"]
|
||||
linear_input = batch["linear"]
|
||||
mel_input = batch["mel"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
stop_targets = batch["stop_targets"]
|
||||
item_idx = batch["item_idxs"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
attn_mask = batch["attns"]
|
||||
waveform = batch["waveform"]
|
||||
pitch = batch["pitch"]
|
||||
max_text_length = torch.max(text_lengths.float())
|
||||
max_spec_length = torch.max(mel_lengths.float())
|
||||
|
||||
|
@ -201,7 +201,7 @@ class BaseTTS(BaseModel):
|
|||
outputs_per_step=config.r if "r" in config else 1,
|
||||
text_cleaner=config.text_cleaner,
|
||||
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
|
||||
comnpute_f0=config.get("compute_f0", False),
|
||||
compute_f0=config.get("compute_f0", False),
|
||||
f0_cache_path=config.get("f0_cache_path", None),
|
||||
meta_data=data_items,
|
||||
ap=ap,
|
||||
|
@ -252,8 +252,8 @@ class BaseTTS(BaseModel):
|
|||
# compute pitch frames and write to files.
|
||||
if config.compute_f0 and rank in [None, 0]:
|
||||
if not os.path.exists(config.f0_cache_path):
|
||||
dataset.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
|
||||
dataset.load_pitch_stats(config.get("f0_cache_path", None))
|
||||
dataset.pitch_extractor.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
|
||||
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
|
||||
|
||||
# halt DDP processes for the main process to finish computing the F0 cache
|
||||
if num_gpus > 1:
|
||||
|
|
|
@ -134,9 +134,9 @@ class GlowTTS(BaseTTS):
|
|||
"""
|
||||
Shapes:
|
||||
- x: :math:`[B, T]`
|
||||
- x_lenghts::math:` B`
|
||||
- x_lenghts::math:`B`
|
||||
- y: :math:`[B, T, C]`
|
||||
- y_lengths::math:` B`
|
||||
- y_lengths::math:`B`
|
||||
- g: :math:`[B, C] or B`
|
||||
"""
|
||||
y = y.transpose(1, 2)
|
||||
|
|
Loading…
Reference in New Issue