Update spec extractor

pull/1324/head
Eren Gölge 2021-12-07 12:59:28 +00:00
parent 0a47a7eac0
commit 4d99fee3e2
1 changed file with 13 additions and 13 deletions

View File

@@ -13,6 +13,7 @@ from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
@@ -20,21 +21,20 @@ use_cuda = torch.cuda.is_available()
def setup_loader(ap, r, verbose=False):
tokenizer, _ = TTSTokenizer.init_from_config(c)
dataset = TTSDataset(
r,
c.text_cleaner,
outputs_per_step=r,
compute_linear_spec=False,
meta_data=meta_data,
samples=meta_data,
tokenizer=tokenizer,
ap=ap,
characters=c.characters if "characters" in c.keys() else None,
add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
batch_group_size=0,
min_seq_len=c.min_seq_len,
max_seq_len=c.max_seq_len,
min_text_len=c.min_text_len,
max_text_len=c.max_text_len,
min_audio_len=c.min_audio_len,
max_audio_len=c.max_audio_len,
phoneme_cache_path=c.phoneme_cache_path,
use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars,
precompute_num_workers=0,
use_noise_augment=False,
verbose=verbose,
speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,
@@ -44,7 +44,7 @@ def setup_loader(ap, r, verbose=False):
if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
dataset.compute_input_seq(c.num_loader_workers)
dataset.sort_and_filter_items(c.get("sort_by_audio_len", default=False))
dataset.preprocess_samples()
loader = DataLoader(
dataset,
@@ -75,8 +75,8 @@ def set_filename(wav_path, out_path):
def format_data(data):
# setup input data
text_input = data["text"]
text_lengths = data["text_lengths"]
text_input = data["token_id"]
text_lengths = data["token_id_lengths"]
mel_input = data["mel"]
mel_lengths = data["mel_lengths"]
item_idx = data["item_idxs"]