Implement VitsAudioConfig (#1556)

* Implement VitsAudioConfig

* Update VITS LJSpeech recipe

* Update VITS VCTK recipe

* Make style

* Add missing decorator

* Add missing param

* Make style

* Update recipes

* Fix test

* Bug fix

* Exclude tests folder

* Make linter

* Make style
pull/1739/head
Eren Gölge 2022-07-12 18:49:58 +02:00 committed by GitHub
parent 34b80e0280
commit 49bac724c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 65 additions and 76 deletions

View File

@ -11,4 +11,5 @@ recursive-include TTS *.md
recursive-include TTS *.py
recursive-include TTS *.pyx
recursive-include images *.png
recursive-exclude tests *
prune tests*

View File

@ -2,7 +2,7 @@ from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits import VitsArgs
from TTS.tts.models.vits import VitsArgs, VitsAudioConfig
@dataclass
@ -16,6 +16,9 @@ class VitsConfig(BaseTTSConfig):
model_args (VitsArgs):
Model architecture arguments. Defaults to `VitsArgs()`.
audio (VitsAudioConfig):
Audio processing configuration. Defaults to `VitsAudioConfig()`.
grad_clip (List):
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
@ -94,6 +97,7 @@ class VitsConfig(BaseTTSConfig):
model: str = "vits"
# model specific params
model_args: VitsArgs = field(default_factory=VitsArgs)
audio: VitsAudioConfig = VitsAudioConfig()
# optimizer
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])

View File

@ -137,7 +137,7 @@ class SSIMLoss(torch.nn.Module):
if ssim_loss.item() < 0.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
ssim_loss = torch.tensor([0.0])
ssim_loss = torch.tensor([0.0])
return ssim_loss

View File

@ -200,6 +200,22 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
return spec
#############################
# CONFIGS
#############################
@dataclass
class VitsAudioConfig(Coqpit):
fft_size: int = 1024
sample_rate: int = 22050
win_length: int = 1024
hop_length: int = 256
num_mels: int = 80
mel_fmin: int = 0
mel_fmax: int = None
##############################
# DATASET
##############################

View File

@ -16,9 +16,9 @@ def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
"""
if reduction == "none":
return x
elif reduction == "mean":
if reduction == "mean":
return x.mean(dim=0)
elif reduction == "sum":
if reduction == "sum":
return x.sum(dim=0)
raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")

View File

@ -307,7 +307,7 @@ class Synthesizer(object):
waveform = waveform.squeeze()
# trim silence
if self.tts_config.audio["do_trim_silence"] is True:
if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]:
waveform = trim_silence(waveform, self.tts_model.ap)
wavs += list(waveform)

View File

@ -54,7 +54,6 @@ config = FastPitchConfig(
print_step=50,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],

View File

@ -53,7 +53,6 @@ config = FastSpeechConfig(
print_step=50,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],

View File

@ -46,7 +46,6 @@ config = SpeedySpeechConfig(
print_step=50,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],

View File

@ -68,7 +68,6 @@ config = Tacotron2Config(
print_step=25,
print_eval=True,
mixed_precision=False,
sort_by_audio_len=True,
seq_len_norm=True,
output_path=output_path,
datasets=[dataset_config],

View File

@ -2,11 +2,10 @@ import os
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits
from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
@ -14,21 +13,8 @@ output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
audio_config = BaseAudioConfig(
sample_rate=22050,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=45,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
audio_config = VitsAudioConfig(
sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
)
config = VitsConfig(
@ -37,7 +23,7 @@ config = VitsConfig(
batch_size=32,
eval_batch_size=16,
batch_group_size=5,
num_loader_workers=0,
num_loader_workers=8,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=-1,
@ -52,6 +38,7 @@ config = VitsConfig(
mixed_precision=True,
output_path=output_path,
datasets=[dataset_config],
cudnn_benchmark=False,
)
# INITIALIZE THE AUDIO PROCESSOR

View File

@ -3,11 +3,10 @@ from glob import glob
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs
from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
@ -22,22 +21,13 @@ dataset_config = [
for path in dataset_paths
]
audio_config = BaseAudioConfig(
audio_config = VitsAudioConfig(
sample_rate=16000,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=False,
trim_db=23.0,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=True,
do_amp_to_db_linear=False,
resample=False,
)
vitsArgs = VitsArgs(
@ -69,7 +59,6 @@ config = VitsConfig(
use_language_weighted_sampler=True,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
min_audio_len=32 * 256 * 4,
max_audio_len=160000,
output_path=output_path,

View File

@ -60,7 +60,6 @@ config = SpeedySpeechConfig(
"Dieser Kuchen ist großartig. Er ist so lecker und feucht.",
"Vor dem 22. November 1963.",
],
sort_by_audio_len=True,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],

View File

@ -2,11 +2,10 @@ import os
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits
from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.downloaders import download_thorsten_de
@ -21,21 +20,13 @@ if not os.path.exists(dataset_config.path):
print("Downloading dataset")
download_thorsten_de(os.path.split(os.path.abspath(dataset_config.path))[0])
audio_config = BaseAudioConfig(
audio_config = VitsAudioConfig(
sample_rate=22050,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=45,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
)
config = VitsConfig(

View File

@ -2,11 +2,10 @@ import os
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
@ -17,22 +16,8 @@ dataset_config = BaseDatasetConfig(
)
audio_config = BaseAudioConfig(
sample_rate=22050,
win_length=1024,
hop_length=256,
num_mels=80,
preemphasis=0.0,
ref_level_db=20,
log_func="np.log",
do_trim_silence=True,
trim_db=23.0,
mel_fmin=0,
mel_fmax=None,
spec_gain=1.0,
signal_norm=False,
do_amp_to_db_linear=False,
resample=True,
audio_config = VitsAudioConfig(
sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
)
vitsArgs = VitsArgs(
@ -62,6 +47,7 @@ config = VitsConfig(
max_text_len=325, # change this if you have a larger VRAM than 16GB
output_path=output_path,
datasets=[dataset_config],
cudnn_benchmark=False,
)
# INITIALIZE THE AUDIO PROCESSOR

View File

@ -90,7 +90,7 @@ setup(
# ext_modules=find_cython_extensions(),
# package
include_package_data=True,
packages=find_packages(include=["TTS*"]),
packages=find_packages(include=["TTS"], exclude=["*.tests", "*tests.*", "tests.*", "*tests", "tests"]),
package_data={
"TTS": [
"VERSION",

View File

@ -9,7 +9,17 @@ from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_test
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
from TTS.tts.models.vits import (
Vits,
VitsArgs,
VitsAudioConfig,
amp_to_db,
db_to_amp,
load_audio,
spec_to_mel,
wav_to_mel,
wav_to_spec,
)
from TTS.tts.utils.speakers import SpeakerManager
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
@ -421,8 +431,10 @@ class TestVits(unittest.TestCase):
self._check_parameter_changes(model, model_ref)
def test_train_step_upsampling(self):
"""Upsampling by the decoder upsampling layers"""
# setup the model
with torch.autograd.set_detect_anomaly(True):
audio_config = VitsAudioConfig(sample_rate=22050)
model_args = VitsArgs(
num_chars=32,
spec_segment_size=10,
@ -430,7 +442,7 @@ class TestVits(unittest.TestCase):
interpolate_z=False,
upsample_rates_decoder=[8, 8, 4, 2],
)
config = VitsConfig(model_args=model_args)
config = VitsConfig(model_args=model_args, audio=audio_config)
model = Vits(config).to(device)
model.train()
# model to train
@ -459,10 +471,18 @@ class TestVits(unittest.TestCase):
self._check_parameter_changes(model, model_ref)
def test_train_step_upsampling_interpolation(self):
"""Upsampling by interpolation"""
# setup the model
with torch.autograd.set_detect_anomaly(True):
model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True)
config = VitsConfig(model_args=model_args)
audio_config = VitsAudioConfig(sample_rate=22050)
model_args = VitsArgs(
num_chars=32,
spec_segment_size=10,
encoder_sample_rate=11025,
interpolate_z=True,
upsample_rates_decoder=[8, 8, 2, 2],
)
config = VitsConfig(model_args=model_args, audio=audio_config)
model = Vits(config).to(device)
model.train()
# model to train