mirror of https://github.com/coqui-ai/TTS.git

Implement VitsAudioConfig (#1556)

* Implement VitsAudioConfig
* Update VITS LJSpeech recipe
* Update VITS VCTK recipe
* Make style
* Add missing decorator
* Add missing param
* Make style
* Update recipes
* Fix test
* Bug fix
* Exclude tests folder
* Make linter
* Make style

Branch: pull/1739/head
parent 34b80e0280 · commit 49bac724c0

MANIFEST.in

@@ -11,4 +11,5 @@ recursive-include TTS *.md
 recursive-include TTS *.py
 recursive-include TTS *.pyx
 recursive-include images *.png
-
+recursive-exclude tests *
+prune tests*

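With `recursive-exclude tests *` and `prune tests*`, source distributions should no longer carry the test suite. A quick way to confirm that after building an sdist (build tooling and the `dist/` location are assumptions, not part of the diff):

    import glob
    import tarfile

    # After e.g. `python -m build --sdist`, check every archive under dist/
    # for leftover tests/ entries that the prune should have removed.
    for path in glob.glob("dist/*.tar.gz"):
        with tarfile.open(path) as tf:
            names = [m.name for m in tf.getmembers()]
        assert not any("/tests/" in n or n.endswith("/tests") for n in names), path
        print(path, "contains no tests/ entries")
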
TTS/tts/configs/vits_config.py

@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
 from typing import List

 from TTS.tts.configs.shared_configs import BaseTTSConfig
-from TTS.tts.models.vits import VitsArgs
+from TTS.tts.models.vits import VitsArgs, VitsAudioConfig


 @dataclass

@@ -16,6 +16,9 @@ class VitsConfig(BaseTTSConfig):
         model_args (VitsArgs):
             Model architecture arguments. Defaults to `VitsArgs()`.

+        audio (VitsAudioConfig):
+            Audio processing configuration. Defaults to `VitsAudioConfig()`.
+
         grad_clip (List):
             Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.

@@ -94,6 +97,7 @@ class VitsConfig(BaseTTSConfig):
     model: str = "vits"
     # model specific params
     model_args: VitsArgs = field(default_factory=VitsArgs)
+    audio: VitsAudioConfig = VitsAudioConfig()

     # optimizer
     grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])

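Note that `audio: VitsAudioConfig = VitsAudioConfig()` shares one default instance across all `VitsConfig` objects, unlike the `default_factory` used for `model_args`. A minimal usage sketch of the new field (assuming the package is installed):

    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import VitsAudioConfig

    # Override only the audio fields you need; the rest keep their defaults.
    audio_config = VitsAudioConfig(sample_rate=16000)
    config = VitsConfig(audio=audio_config)
    print(config.audio.sample_rate)  # 16000
    print(config.audio.hop_length)   # 256 (default)
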
TTS/tts/layers/losses.py

@@ -137,7 +137,7 @@ class SSIMLoss(torch.nn.Module):
         if ssim_loss.item() < 0.0:
             print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
-        ssim_loss = torch.tensor([0.0])
+            ssim_loss = torch.tensor([0.0])

         return ssim_loss

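The fix moves the reset inside the `if`, so the loss is only zeroed when it actually underflows. For context, a standalone sketch of the guard; the symmetric upper-bound branch is an assumption based on the same method, and the `[0.0]` tensor literal is taken from the diff:

    import torch

    def clamp_ssim_loss(ssim_loss: torch.Tensor) -> torch.Tensor:
        # SSIM-derived losses should stay within [0, 1]; floating-point error
        # can push them slightly outside, so warn and reset rather than
        # propagate an out-of-range value.
        if ssim_loss.item() > 1.0:
            print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
            ssim_loss = torch.tensor([1.0])
        if ssim_loss.item() < 0.0:
            print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
            ssim_loss = torch.tensor([0.0])
        return ssim_loss

    print(clamp_ssim_loss(torch.tensor([-0.02])))  # tensor([0.])
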
TTS/tts/models/vits.py

@@ -200,6 +200,22 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax
     return spec


+#############################
+# CONFIGS
+#############################
+
+
+@dataclass
+class VitsAudioConfig(Coqpit):
+    fft_size: int = 1024
+    sample_rate: int = 22050
+    win_length: int = 1024
+    hop_length: int = 256
+    num_mels: int = 80
+    mel_fmin: int = 0
+    mel_fmax: int = None
+
+
 ##############################
 # DATASET
 ##############################

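The new dataclass mirrors the parameters of `wav_to_mel` (as listed in the hunk header) one field to one argument. A hedged sketch of that mapping; the input tensor shape is an assumption, not something the diff shows:

    import torch

    from TTS.tts.models.vits import VitsAudioConfig, wav_to_mel

    cfg = VitsAudioConfig()  # defaults: fft_size=1024, hop_length=256, num_mels=80, ...
    y = torch.rand(1, 22050) * 2 - 1  # one second of fake audio in [-1, 1]; [B, T] shape assumed

    mel = wav_to_mel(
        y,
        n_fft=cfg.fft_size,
        num_mels=cfg.num_mels,
        sample_rate=cfg.sample_rate,
        hop_length=cfg.hop_length,
        win_length=cfg.win_length,
        fmin=cfg.mel_fmin,
        fmax=cfg.mel_fmax,
    )
    print(mel.shape)  # expected roughly [1, num_mels, T // hop_length]
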
TTS/tts/utils/ssim.py

@@ -16,9 +16,9 @@ def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
     """
     if reduction == "none":
         return x
-    elif reduction == "mean":
+    if reduction == "mean":
         return x.mean(dim=0)
-    elif reduction == "sum":
+    if reduction == "sum":
         return x.sum(dim=0)
     raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")

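Behavior is unchanged here: every branch returns, so the `elif`s were redundant (the "Make linter" item in the commit message). A standalone copy to show the semantics:

    import torch

    def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
        # Early returns make elif chains unnecessary, hence the style fix.
        if reduction == "none":
            return x
        if reduction == "mean":
            return x.mean(dim=0)
        if reduction == "sum":
            return x.sum(dim=0)
        raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")

    x = torch.arange(6.0).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
    print(_reduce(x, "mean"))  # tensor([2., 3.])
    print(_reduce(x, "sum"))   # tensor([6., 9.])
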
TTS/utils/synthesizer.py

@@ -307,7 +307,7 @@ class Synthesizer(object):
                 waveform = waveform.squeeze()

                 # trim silence
-                if self.tts_config.audio["do_trim_silence"] is True:
+                if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]:
                     waveform = trim_silence(waveform, self.tts_model.ap)

                 wavs += list(waveform)

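The old check raised a `KeyError` for audio configs without a `do_trim_silence` entry, which the new `VitsAudioConfig` above does not define. Coqpit configs support dict-style `in` and indexing, which is what the new guard relies on. A toy sketch with a hypothetical config class (name and field are illustrative only):

    from dataclasses import dataclass

    from coqpit import Coqpit

    @dataclass
    class ToyAudioConfig(Coqpit):
        sample_rate: int = 22050  # deliberately no do_trim_silence field

    cfg = ToyAudioConfig()
    if "do_trim_silence" in cfg and cfg["do_trim_silence"]:
        print("trimming silence")
    else:
        print("skipping trim")  # this branch runs
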
FastPitch LJSpeech recipe

@@ -54,7 +54,6 @@ config = FastPitchConfig(
     print_step=50,
     print_eval=False,
     mixed_precision=False,
-    sort_by_audio_len=True,
     max_seq_len=500000,
     output_path=output_path,
     datasets=[dataset_config],

FastSpeech LJSpeech recipe

@@ -53,7 +53,6 @@ config = FastSpeechConfig(
     print_step=50,
     print_eval=False,
     mixed_precision=False,
-    sort_by_audio_len=True,
     max_seq_len=500000,
     output_path=output_path,
     datasets=[dataset_config],

SpeedySpeech LJSpeech recipe

@@ -46,7 +46,6 @@ config = SpeedySpeechConfig(
     print_step=50,
     print_eval=False,
     mixed_precision=False,
-    sort_by_audio_len=True,
     max_seq_len=500000,
     output_path=output_path,
     datasets=[dataset_config],

Tacotron2 LJSpeech recipe

@@ -68,7 +68,6 @@ config = Tacotron2Config(
     print_step=25,
     print_eval=True,
     mixed_precision=False,
-    sort_by_audio_len=True,
     seq_len_norm=True,
     output_path=output_path,
     datasets=[dataset_config],

VITS LJSpeech recipe

@@ -2,11 +2,10 @@ import os

 from trainer import Trainer, TrainerArgs

-from TTS.config.shared_configs import BaseAudioConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import Vits
+from TTS.tts.models.vits import Vits, VitsAudioConfig
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

@@ -14,21 +13,8 @@ output_path = os.path.dirname(os.path.abspath(__file__))
 dataset_config = BaseDatasetConfig(
     name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
 )
-audio_config = BaseAudioConfig(
-    sample_rate=22050,
-    win_length=1024,
-    hop_length=256,
-    num_mels=80,
-    preemphasis=0.0,
-    ref_level_db=20,
-    log_func="np.log",
-    do_trim_silence=True,
-    trim_db=45,
-    mel_fmin=0,
-    mel_fmax=None,
-    spec_gain=1.0,
-    signal_norm=False,
-    do_amp_to_db_linear=False,
+audio_config = VitsAudioConfig(
+    sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
 )

 config = VitsConfig(

@@ -37,7 +23,7 @@ config = VitsConfig(
     batch_size=32,
     eval_batch_size=16,
     batch_group_size=5,
-    num_loader_workers=0,
+    num_loader_workers=8,
     num_eval_loader_workers=4,
     run_eval=True,
     test_delay_epochs=-1,

@@ -52,6 +38,7 @@ config = VitsConfig(
     mixed_precision=True,
     output_path=output_path,
     datasets=[dataset_config],
+    cudnn_benchmark=False,
 )

 # INITIALIZE THE AUDIO PROCESSOR

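Putting this recipe's hunks together, a condensed sketch of the updated training flow. The wiring follows the standard TTS recipe pattern; values not visible in the hunks (such as `run_name`) are assumptions:

    import os

    from trainer import Trainer, TrainerArgs

    from TTS.tts.configs.shared_configs import BaseDatasetConfig
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.datasets import load_tts_samples
    from TTS.tts.models.vits import Vits, VitsAudioConfig
    from TTS.tts.utils.text.tokenizer import TTSTokenizer
    from TTS.utils.audio import AudioProcessor

    output_path = os.path.dirname(os.path.abspath(__file__))
    dataset_config = BaseDatasetConfig(
        name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
    )
    audio_config = VitsAudioConfig(
        sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
    )
    config = VitsConfig(
        audio=audio_config,
        run_name="vits_ljspeech",  # assumed run name, not shown in the diff
        batch_size=32,
        num_loader_workers=8,
        output_path=output_path,
        datasets=[dataset_config],
        cudnn_benchmark=False,
    )
    # The AudioProcessor and tokenizer are initialized from the config, so the
    # VitsAudioConfig fields flow through the whole pipeline.
    ap = AudioProcessor.init_from_config(config)
    tokenizer, config = TTSTokenizer.init_from_config(config)
    train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
    model = Vits(config, ap, tokenizer, speaker_manager=None)
    trainer = Trainer(
        TrainerArgs(), config, output_path, model=model,
        train_samples=train_samples, eval_samples=eval_samples,
    )
    # trainer.fit()  # start training
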
Multilingual VITS recipe

@@ -3,11 +3,10 @@ from glob import glob

 from trainer import Trainer, TrainerArgs

-from TTS.config.shared_configs import BaseAudioConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs
+from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer

@@ -22,22 +21,13 @@ dataset_config = [
     for path in dataset_paths
 ]

-audio_config = BaseAudioConfig(
+audio_config = VitsAudioConfig(
     sample_rate=16000,
     win_length=1024,
     hop_length=256,
     num_mels=80,
-    preemphasis=0.0,
-    ref_level_db=20,
-    log_func="np.log",
-    do_trim_silence=False,
-    trim_db=23.0,
     mel_fmin=0,
     mel_fmax=None,
-    spec_gain=1.0,
-    signal_norm=True,
-    do_amp_to_db_linear=False,
-    resample=False,
 )

 vitsArgs = VitsArgs(

@@ -69,7 +59,6 @@ config = VitsConfig(
     use_language_weighted_sampler=True,
     print_eval=False,
     mixed_precision=False,
-    sort_by_audio_len=True,
     min_audio_len=32 * 256 * 4,
     max_audio_len=160000,
     output_path=output_path,

SpeedySpeech Thorsten (German) recipe

@@ -60,7 +60,6 @@ config = SpeedySpeechConfig(
         "Dieser Kuchen ist großartig. Er ist so lecker und feucht.",
         "Vor dem 22. November 1963.",
     ],
-    sort_by_audio_len=True,
     max_seq_len=500000,
     output_path=output_path,
     datasets=[dataset_config],

VITS Thorsten (German) recipe

@@ -2,11 +2,10 @@ import os

 from trainer import Trainer, TrainerArgs

-from TTS.config.shared_configs import BaseAudioConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import Vits
+from TTS.tts.models.vits import Vits, VitsAudioConfig
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.downloaders import download_thorsten_de

@@ -21,21 +20,13 @@ if not os.path.exists(dataset_config.path):
     print("Downloading dataset")
     download_thorsten_de(os.path.split(os.path.abspath(dataset_config.path))[0])

-audio_config = BaseAudioConfig(
+audio_config = VitsAudioConfig(
     sample_rate=22050,
     win_length=1024,
     hop_length=256,
     num_mels=80,
-    preemphasis=0.0,
-    ref_level_db=20,
-    log_func="np.log",
-    do_trim_silence=True,
-    trim_db=45,
     mel_fmin=0,
     mel_fmax=None,
-    spec_gain=1.0,
-    signal_norm=False,
-    do_amp_to_db_linear=False,
 )

 config = VitsConfig(

VITS VCTK recipe

@@ -2,11 +2,10 @@ import os

 from trainer import Trainer, TrainerArgs

-from TTS.config.shared_configs import BaseAudioConfig
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.vits import Vits, VitsArgs
+from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

@@ -17,22 +16,8 @@ dataset_config = BaseDatasetConfig(
 )


-audio_config = BaseAudioConfig(
-    sample_rate=22050,
-    win_length=1024,
-    hop_length=256,
-    num_mels=80,
-    preemphasis=0.0,
-    ref_level_db=20,
-    log_func="np.log",
-    do_trim_silence=True,
-    trim_db=23.0,
-    mel_fmin=0,
-    mel_fmax=None,
-    spec_gain=1.0,
-    signal_norm=False,
-    do_amp_to_db_linear=False,
-    resample=True,
+audio_config = VitsAudioConfig(
+    sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
 )

 vitsArgs = VitsArgs(

@@ -62,6 +47,7 @@ config = VitsConfig(
     max_text_len=325,  # change this if you have a larger VRAM than 16GB
     output_path=output_path,
     datasets=[dataset_config],
+    cudnn_benchmark=False,
 )

 # INITIALIZE THE AUDIO PROCESSOR

setup.py

@@ -90,7 +90,7 @@ setup(
     # ext_modules=find_cython_extensions(),
     # package
     include_package_data=True,
-    packages=find_packages(include=["TTS*"]),
+    packages=find_packages(include=["TTS"], exclude=["*.tests", "*tests.*", "tests.*", "*tests", "tests"]),
     package_data={
         "TTS": [
             "VERSION",

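The `include` and `exclude` arguments of `find_packages` are fnmatch-style patterns applied to dotted package names, so the two forms behave quite differently. A quick check, assumed to be run from the repository root:

    from setuptools import find_packages

    # "TTS*" matches "TTS" and every "TTS.x.y" subpackage (the * also matches
    # dots); a bare "TTS" matches only the literal top-level package name.
    print(find_packages(include=["TTS*"]))
    print(find_packages(
        include=["TTS"],
        exclude=["*.tests", "*tests.*", "tests.*", "*tests", "tests"],
    ))
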
tests/tts_tests/test_vits.py

@@ -9,7 +9,17 @@ from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_test
 from TTS.config import load_config
 from TTS.encoder.utils.generic_utils import setup_encoder_model
 from TTS.tts.configs.vits_config import VitsConfig
-from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
+from TTS.tts.models.vits import (
+    Vits,
+    VitsArgs,
+    VitsAudioConfig,
+    amp_to_db,
+    db_to_amp,
+    load_audio,
+    spec_to_mel,
+    wav_to_mel,
+    wav_to_spec,
+)
 from TTS.tts.utils.speakers import SpeakerManager

 LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")

@@ -421,8 +431,10 @@ class TestVits(unittest.TestCase):
         self._check_parameter_changes(model, model_ref)

     def test_train_step_upsampling(self):
+        """Upsampling by the decoder upsampling layers"""
         # setup the model
         with torch.autograd.set_detect_anomaly(True):
+            audio_config = VitsAudioConfig(sample_rate=22050)
             model_args = VitsArgs(
                 num_chars=32,
                 spec_segment_size=10,

@@ -430,7 +442,7 @@ class TestVits(unittest.TestCase):
                 interpolate_z=False,
                 upsample_rates_decoder=[8, 8, 4, 2],
             )
-            config = VitsConfig(model_args=model_args)
+            config = VitsConfig(model_args=model_args, audio=audio_config)
             model = Vits(config).to(device)
             model.train()
             # model to train

@@ -459,10 +471,18 @@ class TestVits(unittest.TestCase):
         self._check_parameter_changes(model, model_ref)

     def test_train_step_upsampling_interpolation(self):
+        """Upsampling by interpolation"""
         # setup the model
         with torch.autograd.set_detect_anomaly(True):
-            model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True)
-            config = VitsConfig(model_args=model_args)
+            audio_config = VitsAudioConfig(sample_rate=22050)
+            model_args = VitsArgs(
+                num_chars=32,
+                spec_segment_size=10,
+                encoder_sample_rate=11025,
+                interpolate_z=True,
+                upsample_rates_decoder=[8, 8, 2, 2],
+            )
+            config = VitsConfig(model_args=model_args, audio=audio_config)
             model = Vits(config).to(device)
             model.train()
             # model to train

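Both updated tests now pin the audio config explicitly, and the decoder rates differ between them for a reason: with `interpolate_z=True` the latent is interpolated up to the output rate before decoding, so the decoder product must equal the hop length (8 * 8 * 2 * 2 = 256 = `hop_length`), while the non-interpolating test uses [8, 8, 4, 2] = 512 to also recover the 2x reduction of the half-rate (11025 Hz) encoder. That reading is inferred from the two tests, not stated in the diff. A condensed sketch of the new setup:

    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig

    audio_config = VitsAudioConfig(sample_rate=22050)  # hop_length defaults to 256
    model_args = VitsArgs(
        num_chars=32,
        spec_segment_size=10,
        encoder_sample_rate=11025,            # half the output rate
        interpolate_z=True,                   # z is interpolated up before the decoder
        upsample_rates_decoder=[8, 8, 2, 2],  # 8 * 8 * 2 * 2 = 256 = hop_length
    )
    config = VitsConfig(model_args=model_args, audio=audio_config)
    model = Vits(config)
    model.train()
    print(sum(p.numel() for p in model.parameters()))  # parameter-count sanity check
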