From 037ec134536ca6dff7e64837167cbf12ea3a1fdc Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 14 Jun 2019 16:18:49 +0200 Subject: [PATCH] config update, audio.py update and modularize synthesize.py --- .compute | 6 +- config_tacotron_de.json | 198 +++++++++++++++++++++++----------------- datasets/preprocess.py | 11 ++- utils/audio.py | 17 +--- utils/synthesis.py | 95 +++++++++++-------- 5 files changed, 184 insertions(+), 143 deletions(-) diff --git a/.compute b/.compute index 54c96861..fd21fdaa 100644 --- a/.compute +++ b/.compute @@ -10,7 +10,7 @@ wget https://www.dropbox.com/s/wqn5v3wkktw9lmo/install.sh?dl=0 -O install.sh sudo sh install.sh python3 setup.py develop # cp -R ${USER_DIR}/GermanData ../tmp/ -# python3 distribute.py --config_path config_tacotron_de.json --data_path ../tmp/GermanData/karlsson/ -cp -R ${USER_DIR}/Mozilla_22050 ../tmp/ -python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/ +python3 distribute.py --config_path config_tacotron_de.json --data_path ${USER_DIR}/GermanData/karlsson/ +# cp -R ${USER_DIR}/Mozilla_22050 ../tmp/ +# python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/ while true; do sleep 1000000; done diff --git a/config_tacotron_de.json b/config_tacotron_de.json index bf80bcc6..0a13c3eb 100644 --- a/config_tacotron_de.json +++ b/config_tacotron_de.json @@ -1,85 +1,117 @@ { - "run_name": "german-tacotron-tagent-bn", - "run_description": "train german", - - "audio":{ - // Audio processing parameters - "num_mels": 80, // size of the mel spec frame. - "num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame. - "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. - "frame_length_ms": 50, // stft window length in ms. - "frame_shift_ms": 12.5, // stft window hop-lengh in ms. - "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. - "min_level_db": -100, // normalization range - "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. - "power": 1.5, // value to sharpen wav signals after GL algorithm. - "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. - // Normalization parameters - "signal_norm": true, // normalize the spec values in range [0, 1] - "symmetric_norm": false, // move normalization to range [-1, 1] - "max_norm": 1, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] - "clip_norm": true, // clip normalized values into the range. - "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! - "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! - "do_trim_silence": true // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) - }, - - "distributed":{ - "backend": "nccl", - "url": "tcp:\/\/localhost:54321" - }, - - "reinit_layers": [], - - "model": "Tacotron", // one of the model in models/ - "grad_clip": 1, // upper limit for gradients for clipping. - "epochs": 1000, // total number of epochs to train. - "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. - "lr_decay": false, // if true, Noam learning rate decaying is applied through training. - "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" - "windowing": false, // Enables attention windowing. Used only in eval mode. - "memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5. - "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron. - "prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn". - "prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet. - "use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster. - "transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention. - "location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default. - "loss_masking": true, // enable / disable loss masking against the sequence padding. - "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. - "stopnet": true, // Train stopnet predicting the end of synthesis. - "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER. - "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + "github_branch":"tacotron-gst-softmax", + "run_name": "german-tacotron-gst-softmax", + "run_description": "train german with all of the german dataset", - "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. - "eval_batch_size":16, - "r": 5, // Number of frames to predict for step. - "wd": 0.000001, // Weight decay weight. - "checkpoint": true, // If true, it saves checkpoints per "save_step" - "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. - "print_step": 10, // Number of steps to log traning on console. - "batch_group_size": 0, //Number of batches to shuffle after bucketing. - - "run_eval": false, - "test_sentences_file": "de_sentences.txt", // set a file to load sentences to be used for testing. If it is null then we use default english sentences. - "test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time. - "data_path": "/media/erogol/data_ssd/Data/Mozilla/", // DATASET-RELATED: can overwritten from command argument - "meta_file_train": [ - "grune_haus/metadata.csv", - "kleine_lord/metadata.csv", - "toten_seelen/metadata.csv", - "werde_die_du_bist/metadata.csv" - ], // DATASET-RELATED: metafile for training dataloader. - "meta_file_val": "metadata_val.txt", // DATASET-RELATED: metafile for evaluation dataloader. - "dataset": "mailabs", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py - "min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training - "max_seq_len": 200, // DATASET-RELATED: maximum text length - "output_path": "/media/erogol/data_ssd/Data/models/german/", // DATASET-RELATED: output path for all training outputs. - "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. - "num_val_loader_workers": 4, // number of evaluation data loader processes. - "phoneme_cache_path": "phoneme_cache", // phoneme computation is slow, therefore, it caches results in the given folder. - "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. - "phoneme_language": "de", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages - "text_cleaner": "phoneme_cleaners" - } - \ No newline at end of file + "audio":{ + // Audio processing parameters + "num_mels": 80, // size of the mel spec frame. + "num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame. + "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. + "frame_length_ms": 50, // stft window length in ms. + "frame_shift_ms": 12.5, // stft window hop-lengh in ms. + "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. + "min_level_db": -100, // normalization range + "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. + "power": 1.5, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. + // Normalization parameters + "signal_norm": true, // normalize the spec values in range [0, 1] + "symmetric_norm": false, // move normalization to range [-1, 1] + "max_norm": 1, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! + "do_trim_silence": true // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + }, + + "distributed":{ + "backend": "nccl", + "url": "tcp:\/\/localhost:54321" + }, + + "reinit_layers": [], + + "model": "Tacotron", // one of the model in models/ + "grad_clip": 1, // upper limit for gradients for clipping. + "epochs": 10000, // total number of epochs to train. + "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. + "lr_decay": false, // if true, Noam learning rate decaying is applied through training. + "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + "windowing": false, // Enables attention windowing. Used only in eval mode. + "memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5. + "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron. + "prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn". + "prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet. + "use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster. + "transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention. + "forward_attn_mask": true, + "location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default. + "loss_masking": true, // enable / disable loss masking against the sequence padding. + "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. + "stopnet": true, // Train stopnet predicting the end of synthesis. + "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER. + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + + "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. + "eval_batch_size":32, + "r": 5, // Number of frames to predict for step. + "wd": 0.000001, // Weight decay weight. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. + "print_step": 10, // Number of steps to log traning on console. + "batch_group_size": 0, //Number of batches to shuffle after bucketing. + + "run_eval": false, + "test_sentences_file": "de_sentences.txt", // set a file to load sentences to be used for testing. If it is null then we use default english sentences. + "test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time. + "data_path": "/home/erogol/Data/m-ai-labs/de_DE/by_book/" , // DATASET-RELATED: can overwritten from command argument + "meta_file_train": [ + "/home/erogol/Data/m-ai-labs/de_DE/by_book/mix/erzaehlungen_poe/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/mix/auf_zwei_planeten/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/kleinzaches/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/spiegel_kaetzchen/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/herrnarnesschatz/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/maedchen_von_moorhof/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/koenigsgaukler/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/altehous/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/odysseus/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/undine/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/reise_tilsit/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/schmied_seines_glueckes/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/kammmacher/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/unterm_birnbaum/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/liebesbriefe/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/male/karlsson/sandmann/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/eva_k/kleine_lord/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/eva_k/toten_seelen/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/eva_k/werde_die_du_bist/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/eva_k/grune_haus/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/rebecca_braunert_plunkett/das_letzte_marchen/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/rebecca_braunert_plunkett/ferien_vom_ich/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/rebecca_braunert_plunkett/maerchen/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/rebecca_braunert_plunkett/mein_weg_als_deutscher_und_jude/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/ramona_deininger/caspar/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/ramona_deininger/sterben/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/ramona_deininger/weihnachtsabend/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/ramona_deininger/frankenstein/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/ramona_deininger/tschun/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/ramona_deininger/menschenhasser/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/ramona_deininger/grune_gesicht/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/ramona_deininger/tom_sawyer/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/ramona_deininger/alter_afrikaner/metadata.csv", + "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/angela_merkel/merkel_alone/metadata.csv" + ], // DATASET-RELATED: metafile for training dataloader. + "meta_file_val": null, // DATASET-RELATED: metafile for evaluation dataloader. + "dataset": "mailabs", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py + "min_seq_len": 15, // DATASET-RELATED: minimum text length to use in training + "max_seq_len": 200, // DATASET-RELATED: maximum text length + "output_path": "/media/erogol/data_ssd/Data/models/mozilla_models/", // DATASET-RELATED: output path for all training outputs. + "num_loader_workers": 0, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "num_val_loader_workers": 4, // number of evaluation data loader processes. + "phoneme_cache_path": "phoneme_cache", // phoneme computation is slow, therefore, it caches results in the given folder. + "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. + "phoneme_language": "de", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages + "text_cleaner": "phoneme_cleaners" + } + \ No newline at end of file diff --git a/datasets/preprocess.py b/datasets/preprocess.py index 07c025df..1db1c189 100644 --- a/datasets/preprocess.py +++ b/datasets/preprocess.py @@ -63,7 +63,9 @@ def mailabs(root_path, meta_files): """Normalizes M-AI-Labs meta data files to TTS format""" if meta_files is None: meta_files = glob(root_path+"/**/metadata.csv", recursive=True) - folders = [os.path.dirname(f.strip()) for f in meta_files] + folders = [f.strip().split("/")[-2] for f in meta_files] + else: + folders = [f.strip().split("by_book")[1][1:] for f in meta_files] # meta_files = [f.strip() for f in meta_files.split(",")] items = [] for idx, meta_file in enumerate(meta_files): @@ -73,13 +75,12 @@ def mailabs(root_path, meta_files): with open(txt_file, 'r') as ttf: for line in ttf: cols = line.split('|') - wav_file = os.path.join(root_path, folder, 'wavs', - cols[0] + '.wav') + wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), 'wavs', cols[0] + '.wav') if os.path.isfile(wav_file): - text = cols[1] + text = cols[1].strip() items.append([text, wav_file]) else: - continue + raise RuntimeError("> File %s is not exist!"%(wav_file)) return items diff --git a/utils/audio.py b/utils/audio.py index 93ecd684..e14f2b7e 100644 --- a/utils/audio.py +++ b/utils/audio.py @@ -216,32 +216,23 @@ class AudioProcessor(object): return librosa.effects.trim( wav, top_db=40, frame_length=1024, hop_length=256)[0] - def mulaw_encode(self, wav, qc): + @staticmethod + def mulaw_encode(wav, qc): mu = 2 ** qc - 1 - # wav_abs = np.minimum(np.abs(wav), 1.0) signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1. + mu) - # Quantize signal to the specified number of levels. signal = (signal + 1) / 2 * mu + 0.5 - return np.floor(signal,) + return np.floor(signal) @staticmethod def mulaw_decode(wav, qc): """Recovers waveform from quantized values.""" - # from IPython.core.debugger import set_trace - # set_trace() mu = 2 ** qc - 1 x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) return x - # mu = 2 ** qc - 1. - # # Map values back to [-1, 1]. - # # casted = wav.astype(np.float32) - # # signal = 2 * casted / mu - 1 - # # Perform inverse of mu-law transformation. - # magnitude = (1 / mu) * ((1 + mu) ** abs(wav) - 1) - # return np.sign(wav) * magnitude def load_wav(self, filename, encode=False): x, sr = sf.read(filename) + # x, sr = librosa.load(filename, sr=self.sample_rate) if self.do_trim_silence: x = self.trim_silence(x) # sr, x = io.wavfile.read(filename) diff --git a/utils/synthesis.py b/utils/synthesis.py index 5b8dc685..fce44dfe 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -10,36 +10,73 @@ from matplotlib import pylab as plt def text_to_seqvec(text, CONFIG, use_cuda): text_cleaner = [CONFIG.text_cleaner] + # text ot phonemes to sequence vector if CONFIG.use_phonemes: seq = np.asarray( - phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, enable_eos_bos_chars), + phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, + CONFIG.enable_eos_bos_chars), dtype=np.int32) else: seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32) + # torch tensor chars_var = torch.from_numpy(seq).unsqueeze(0) if use_cuda: chars_var = chars_var.cuda() return chars_var.long() -def compute_style_mel(style_wav, ap): - style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav))).unsqueeze(0) - return style_mel +def compute_style_mel(style_wav, ap, use_cuda): + print(style_wav) + style_mel = torch.FloatTensor(ap.melspectrogram( + ap.load_wav(style_wav))).unsqueeze(0) + if use_cuda: + return style_mel.cuda() + else: + return style_mel -def run_model(): - pass +def run_model(model, inputs, CONFIG, truncated, style_mel=None): + if CONFIG.model == "TacotronGST" and style_mel is not None: + decoder_output, postnet_output, alignments, stop_tokens = model.inference( + inputs, style_mel) + else: + if truncated: + decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( + inputs) + else: + decoder_output, postnet_output, alignments, stop_tokens = model.inference( + inputs) + return decoder_output, postnet_output, alignments, stop_tokens -def parse_outputs(): - pass +def parse_outputs(postnet_output, decoder_output, alignments): + postnet_output = postnet_output[0].data.cpu().numpy() + decoder_output = decoder_output[0].data.cpu().numpy() + alignment = alignments[0].cpu().data.numpy() + return postnet_output, decoder_output, alignment -def trim_silence(): - pass +def trim_silence(wav): + return wav[:ap.find_endpoint(wav)] -def synthesis(model, text, CONFIG, use_cuda, ap, style_wav=None, truncated=False, enable_eos_bos_chars=False, trim_silence=False): +def inv_spectrogram(postnet_output, ap, CONFIG): + if CONFIG.model in ["Tacotron", "TacotronGST"]: + wav = ap.inv_spectrogram(postnet_output.T) + else: + wav = ap.inv_mel_spectrogram(postnet_output.T) + return wav + + +def synthesis(model, + text, + CONFIG, + use_cuda, + ap, + style_wav=None, + truncated=False, + enable_eos_bos_chars=False, + trim_silence=False): """Synthesize voice for the given text. Args: @@ -57,38 +94,18 @@ def synthesis(model, text, CONFIG, use_cuda, ap, style_wav=None, truncated=False """ # GST processing if CONFIG.model == "TacotronGST" and style_wav is not None: - style_mel = compute_style_mel(style_wav, ap) - + style_mel = compute_style_mel(style_wav, ap, use_cuda) # preprocess the given text - text_cleaner = [CONFIG.text_cleaner] - if CONFIG.use_phonemes: - seq = np.asarray( - phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, enable_eos_bos_chars), - dtype=np.int32) - else: - seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32) - chars_var = torch.from_numpy(seq).unsqueeze(0) + inputs = text_to_seqvec(text, CONFIG, use_cuda) # synthesize voice - if CONFIG.model == "TacotronGST" and style_wav is not None: - decoder_output, postnet_output, alignments, stop_tokens = model.inference( - chars_var.long(), style_mel) - else: - if truncated: - decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( - chars_var.long()) - else: - decoder_output, postnet_output, alignments, stop_tokens = model.inference( - chars_var.long()) + decoder_output, postnet_output, alignments, stop_tokens = run_model( + model, inputs, CONFIG, truncated, style_mel) # convert outputs to numpy - postnet_output = postnet_output[0].data.cpu().numpy() - decoder_output = decoder_output[0].data.cpu().numpy() - alignment = alignments[0].cpu().data.numpy() + postnet_output, decoder_output, alignment = parse_outputs( + postnet_output, decoder_output, alignments) # plot results - if CONFIG.model in ["Tacotron", "TacotronGST"]: - wav = ap.inv_spectrogram(postnet_output.T) - else: - wav = ap.inv_mel_spectrogram(postnet_output.T) + wav = inv_spectrogram(postnet_output, ap, CONFIG) # trim silence if trim_silence: - wav = wav[:ap.find_endpoint(wav)] + wav = trim_silence(wav) return wav, alignment, decoder_output, postnet_output, stop_tokens \ No newline at end of file