Bug fix in MP3 and FLAC compute length on TTSDataset (#3092)

* Bug Fix on XTTS load

* Bug fix in MP3 length on TTSDataset

* Update TTS/tts/datasets/dataset.py

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Uses mutagen for all audio formats

* Add dataloader test wit hall supported audio formats

* Use mutagen.File

* Update

* Fix aux unit tests

* Bug fixe on unit tests

---------

Co-authored-by: Aarni Koskela <akx@iki.fi>
pull/3487/head
Edresson Casanova 2023-12-27 13:23:43 -03:00 committed by GitHub
parent 55c7063724
commit 5dcc16d193
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
71 changed files with 154 additions and 100 deletions

View File

@ -13,6 +13,8 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy
import mutagen
# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
torch.multiprocessing.set_sharing_strategy("file_system")
@ -42,6 +44,15 @@ def string2filename(string):
return filename
def get_audio_size(audiopath):
extension = audiopath.rpartition(".")[-1].lower()
if extension not in {"mp3", "wav", "flac"}:
raise RuntimeError(f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!")
audio_info = mutagen.File(audiopath).info
return int(audio_info.length * audio_info.sample_rate)
class TTSDataset(Dataset):
def __init__(
self,
@ -176,7 +187,7 @@ class TTSDataset(Dataset):
lens = []
for item in self.samples:
_, wav_file, *_ = _parse_sample(item)
audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio
audio_len = get_audio_size(wav_file)
lens.append(audio_len)
return lens
@ -295,7 +306,7 @@ class TTSDataset(Dataset):
def _compute_lengths(samples):
new_samples = []
for item in samples:
audio_length = os.path.getsize(item["audio_file"]) / 16 * 8 # assuming 16bit audio
audio_length = get_audio_size(item["audio_file"])
text_lenght = len(item["text"])
item["audio_length"] = audio_length
item["text_length"] = text_lenght

View File

@ -756,11 +756,13 @@ class Xtts(BaseTTS):
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers_xtts.pth")
if speaker_file_path is None and checkpoint_dir is not None:
speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth")
self.language_manager = LanguageManager(config)
self.speaker_manager = None
if os.path.exists(speaker_file_path):
if speaker_file_path is not None and os.path.exists(speaker_file_path):
self.speaker_manager = SpeakerManager(speaker_file_path)
if os.path.exists(vocab_path):

View File

@ -17,6 +17,7 @@ pyyaml>=6.0
fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail
aiohttp>=3.8.1
packaging>=23.1
mutagen==1.47.0
# deps for examples
flask>=2.0.1
# deps for inference

View File

@ -0,0 +1,9 @@
audio_file|text|transcription|speaker_name
wavs/LJ001-0001.flac|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition|ljspeech-0
wavs/LJ001-0002.flac|in being comparatively modern.|in being comparatively modern.|ljspeech-0
wavs/LJ001-0003.flac|For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process|For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process|ljspeech-1
wavs/LJ001-0004.flac|produced the block books, which were the immediate predecessors of the true printed book,|produced the block books, which were the immediate predecessors of the true printed book,|ljspeech-1
wavs/LJ001-0005.flac|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|ljspeech-2
wavs/LJ001-0006.flac|And it is worth mention in passing that, as an example of fine typography,|And it is worth mention in passing that, as an example of fine typography,|ljspeech-2
wavs/LJ001-0007.flac|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about 1455,|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about fourteen fifty-five,|ljspeech-3
wavs/LJ001-0008.flac|has never been surpassed.|has never been surpassed.|ljspeech-3
Can't render this file because it contains an unexpected character in line 8 and column 86.

View File

@ -0,0 +1,9 @@
audio_file|text|transcription|speaker_name
wavs/LJ001-0001.mp3|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition|ljspeech-0
wavs/LJ001-0002.mp3|in being comparatively modern.|in being comparatively modern.|ljspeech-0
wavs/LJ001-0003.mp3|For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process|For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process|ljspeech-1
wavs/LJ001-0004.mp3|produced the block books, which were the immediate predecessors of the true printed book,|produced the block books, which were the immediate predecessors of the true printed book,|ljspeech-1
wavs/LJ001-0005.mp3|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|ljspeech-2
wavs/LJ001-0006.mp3|And it is worth mention in passing that, as an example of fine typography,|And it is worth mention in passing that, as an example of fine typography,|ljspeech-2
wavs/LJ001-0007.mp3|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about 1455,|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about fourteen fifty-five,|ljspeech-3
wavs/LJ001-0008.mp3|has never been surpassed.|has never been surpassed.|ljspeech-3
Can't render this file because it contains an unexpected character in line 8 and column 85.

View File

@ -0,0 +1,9 @@
audio_file|text|transcription|speaker_name
wavs/LJ001-0001.wav|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition|Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition|ljspeech-0
wavs/LJ001-0002.wav|in being comparatively modern.|in being comparatively modern.|ljspeech-0
wavs/LJ001-0003.wav|For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process|For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process|ljspeech-1
wavs/LJ001-0004.wav|produced the block books, which were the immediate predecessors of the true printed book,|produced the block books, which were the immediate predecessors of the true printed book,|ljspeech-1
wavs/LJ001-0005.wav|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|ljspeech-2
wavs/LJ001-0006.wav|And it is worth mention in passing that, as an example of fine typography,|And it is worth mention in passing that, as an example of fine typography,|ljspeech-2
wavs/LJ001-0007.wav|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about 1455,|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about fourteen fifty-five,|ljspeech-3
wavs/LJ001-0008.wav|has never been surpassed.|has never been surpassed.|ljspeech-3
Can't render this file because it contains an unexpected character in line 8 and column 85.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -21,15 +21,30 @@ os.makedirs(OUTPATH, exist_ok=True)
c = BaseTTSConfig(text_cleaner="english_cleaners", num_loader_workers=0, batch_size=2, use_noise_augment=False)
c.r = 5
c.data_path = os.path.join(get_tests_data_path(), "ljspeech/")
ok_ljspeech = os.path.exists(c.data_path)
dataset_config = BaseDatasetConfig(
formatter="ljspeech_test", # ljspeech_test to multi-speaker
meta_file_train="metadata.csv",
dataset_config_wav = BaseDatasetConfig(
formatter="coqui", # ljspeech_test to multi-speaker
meta_file_train="metadata_wav.csv",
meta_file_val=None,
path=c.data_path,
language="en",
)
dataset_config_mp3 = BaseDatasetConfig(
formatter="coqui", # ljspeech_test to multi-speaker
meta_file_train="metadata_mp3.csv",
meta_file_val=None,
path=c.data_path,
language="en",
)
dataset_config_flac = BaseDatasetConfig(
formatter="coqui", # ljspeech_test to multi-speaker
meta_file_train="metadata_flac.csv",
meta_file_val=None,
path=c.data_path,
language="en",
)
dataset_configs = [dataset_config_wav, dataset_config_mp3, dataset_config_flac]
DATA_EXIST = True
if not os.path.exists(c.data_path):
@ -44,11 +59,10 @@ class TestTTSDataset(unittest.TestCase):
self.max_loader_iter = 4
self.ap = AudioProcessor(**c.audio)
def _create_dataloader(self, batch_size, r, bgs, start_by_longest=False):
def _create_dataloader(self, batch_size, r, bgs, dataset_config, start_by_longest=False, preprocess_samples=False):
# load dataset
meta_data_train, meta_data_eval = load_tts_samples(dataset_config, eval_split=True, eval_split_size=0.2)
items = meta_data_train + meta_data_eval
tokenizer, _ = TTSTokenizer.init_from_config(c)
dataset = TTSDataset(
outputs_per_step=r,
@ -64,6 +78,11 @@ class TestTTSDataset(unittest.TestCase):
max_audio_len=c.max_audio_len,
start_by_longest=start_by_longest,
)
# add preprocess to force the length computation
if preprocess_samples:
dataset.preprocess_samples()
dataloader = DataLoader(
dataset,
batch_size=batch_size,
@ -75,9 +94,8 @@ class TestTTSDataset(unittest.TestCase):
return dataloader, dataset
def test_loader(self):
if ok_ljspeech:
dataloader, dataset = self._create_dataloader(1, 1, 0)
for dataset_config in dataset_configs:
dataloader, _ = self._create_dataloader(1, 1, 0, dataset_config, preprocess_samples=True)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
@ -104,8 +122,6 @@ class TestTTSDataset(unittest.TestCase):
# make sure that the computed mels and the waveform match and correctly computed
mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
# remove padding in mel-spectrogram
mel_dataloader = mel_input[0].T.numpy()[:, : mel_lengths[0]]
# guarantee that both mel-spectrograms have the same size and that we will remove waveform padding
mel_new = mel_new[:, : mel_lengths[0]]
ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
@ -124,40 +140,38 @@ class TestTTSDataset(unittest.TestCase):
self.assertGreaterEqual(mel_input.min(), 0)
def test_batch_group_shuffle(self):
if ok_ljspeech:
dataloader, dataset = self._create_dataloader(2, c.r, 16)
last_length = 0
frames = dataset.samples
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
mel_lengths = data["mel_lengths"]
avg_length = mel_lengths.numpy().mean()
dataloader.dataset.preprocess_samples()
is_items_reordered = False
for idx, item in enumerate(dataloader.dataset.samples):
if item != frames[idx]:
is_items_reordered = True
break
self.assertGreaterEqual(avg_length, last_length)
self.assertTrue(is_items_reordered)
dataloader, dataset = self._create_dataloader(2, c.r, 16, dataset_config_wav)
last_length = 0
frames = dataset.samples
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
mel_lengths = data["mel_lengths"]
avg_length = mel_lengths.numpy().mean()
dataloader.dataset.preprocess_samples()
is_items_reordered = False
for idx, item in enumerate(dataloader.dataset.samples):
if item != frames[idx]:
is_items_reordered = True
break
self.assertGreaterEqual(avg_length, last_length)
self.assertTrue(is_items_reordered)
def test_start_by_longest(self):
"""Test start_by_longest option.
Ther first item of the fist batch must be longer than all the other items.
"""
if ok_ljspeech:
dataloader, _ = self._create_dataloader(2, c.r, 0, True)
dataloader.dataset.preprocess_samples()
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
mel_lengths = data["mel_lengths"]
if i == 0:
max_len = mel_lengths[0]
print(mel_lengths)
self.assertTrue(all(max_len >= mel_lengths))
dataloader, _ = self._create_dataloader(2, c.r, 0, dataset_config_wav, start_by_longest=True)
dataloader.dataset.preprocess_samples()
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
mel_lengths = data["mel_lengths"]
if i == 0:
max_len = mel_lengths[0]
print(mel_lengths)
self.assertTrue(all(max_len >= mel_lengths))
def test_padding_and_spectrograms(self):
def check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths):
@ -172,71 +186,70 @@ class TestTTSDataset(unittest.TestCase):
self.assertEqual(mel_lengths[idx], linear_input[idx].shape[0])
self.assertEqual(mel_lengths[idx], mel_input[idx].shape[0])
if ok_ljspeech:
dataloader, _ = self._create_dataloader(1, 1, 0)
dataloader, _ = self._create_dataloader(1, 1, 0, dataset_config_wav)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
# check mel_spec consistency
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
mel = self.ap.melspectrogram(wav).astype("float32")
mel = torch.FloatTensor(mel).contiguous()
mel_dl = mel_input[0]
# NOTE: Below needs to check == 0 but due to an unknown reason
# there is a slight difference between two matrices.
# TODO: Check this assert cond more in detail.
self.assertLess(abs(mel.T - mel_dl).max(), 1e-5)
# check mel_spec consistency
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
mel = self.ap.melspectrogram(wav).astype("float32")
mel = torch.FloatTensor(mel).contiguous()
mel_dl = mel_input[0]
# NOTE: Below needs to check == 0 but due to an unknown reason
# there is a slight difference between two matrices.
# TODO: Check this assert cond more in detail.
self.assertLess(abs(mel.T - mel_dl).max(), 1e-5)
# check mel-spec correctness
mel_spec = mel_input[0].cpu().numpy()
wav = self.ap.inv_melspectrogram(mel_spec.T)
self.ap.save_wav(wav, OUTPATH + "/mel_inv_dataloader.wav")
shutil.copy(item_idx[0], OUTPATH + "/mel_target_dataloader.wav")
# check mel-spec correctness
mel_spec = mel_input[0].cpu().numpy()
wav = self.ap.inv_melspectrogram(mel_spec.T)
self.ap.save_wav(wav, OUTPATH + "/mel_inv_dataloader.wav")
shutil.copy(item_idx[0], OUTPATH + "/mel_target_dataloader.wav")
# check linear-spec
linear_spec = linear_input[0].cpu().numpy()
wav = self.ap.inv_spectrogram(linear_spec.T)
self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav")
shutil.copy(item_idx[0], OUTPATH + "/linear_target_dataloader.wav")
# check linear-spec
linear_spec = linear_input[0].cpu().numpy()
wav = self.ap.inv_spectrogram(linear_spec.T)
self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav")
shutil.copy(item_idx[0], OUTPATH + "/linear_target_dataloader.wav")
# check the outputs
check_conditions(0, linear_input, mel_input, stop_target, mel_lengths)
# check the outputs
check_conditions(0, linear_input, mel_input, stop_target, mel_lengths)
# Test for batch size 2
dataloader, _ = self._create_dataloader(2, 1, 0)
# Test for batch size 2
dataloader, _ = self._create_dataloader(2, 1, 0, dataset_config_wav)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
linear_input = data["linear"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
stop_target = data["stop_targets"]
item_idx = data["item_idxs"]
# set id to the longest sequence in the batch
if mel_lengths[0] > mel_lengths[1]:
idx = 0
else:
idx = 1
# set id to the longest sequence in the batch
if mel_lengths[0] > mel_lengths[1]:
idx = 0
else:
idx = 1
# check the longer item in the batch
check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths)
# check the longer item in the batch
check_conditions(idx, linear_input, mel_input, stop_target, mel_lengths)
# check the other item in the batch
self.assertEqual(linear_input[1 - idx, -1].sum(), 0)
self.assertEqual(mel_input[1 - idx, -1].sum(), 0)
self.assertEqual(stop_target[1, mel_lengths[1] - 1], 1)
self.assertEqual(stop_target[1, mel_lengths[1] :].sum(), stop_target.shape[1] - mel_lengths[1])
self.assertEqual(len(mel_lengths.shape), 1)
# check the other item in the batch
self.assertEqual(linear_input[1 - idx, -1].sum(), 0)
self.assertEqual(mel_input[1 - idx, -1].sum(), 0)
self.assertEqual(stop_target[1, mel_lengths[1] - 1], 1)
self.assertEqual(stop_target[1, mel_lengths[1] :].sum(), stop_target.shape[1] - mel_lengths[1])
self.assertEqual(len(mel_lengths.shape), 1)
# check batch zero-frame conditions (zero-frame disabled)
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
# check batch zero-frame conditions (zero-frame disabled)
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0