Update VITS tests

pull/1324/head
Eren Gölge 2022-02-20 11:56:21 +01:00
parent 8b3ba02c95
commit c0b40a0cb7
1 changed files with 71 additions and 43 deletions

View File

@ -3,17 +3,19 @@ import os
import unittest
import torch
from TTS.tts.datasets.formatters import ljspeech
from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.models.vits import Vits, VitsArgs, load_audio, amp_to_db, db_to_amp, wav_to_spec, wav_to_mel, spec_to_mel, VitsDataset
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
from trainer.logging.tensorboard_logger import TensorboardLogger
LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
torch.manual_seed(1)
@ -23,6 +25,28 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# pylint: disable=no-self-use
class TestVits(unittest.TestCase):
def test_load_audio(self):
wav, sr = load_audio(WAV_FILE)
self.assertEqual(wav.shape, (1, 41885))
self.assertEqual(sr, 22050)
spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False)
mel = wav_to_mel(wav, n_fft=1024, num_mels=80, sample_rate=sr, hop_length=512, win_length=1024, fmin=0, fmax=8000, center=False)
mel2 = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000)
self.assertEqual((mel - mel2).abs().max(), 0)
self.assertEqual(spec.shape[0], mel.shape[0])
self.assertEqual(spec.shape[2], mel.shape[2])
spec_db = amp_to_db(spec)
spec_amp = db_to_amp(spec_db)
self.assertAlmostEqual((spec - spec_amp).abs().max(), 0, delta=1e-4)
def test_dataset(self):
"""TODO:"""
...
def test_init_multispeaker(self):
num_speakers = 10
args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
@ -107,10 +131,11 @@ class TestVits(unittest.TestCase):
input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
input_lengths[-1] = 128
spec = torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, 30).to(device)
mel = torch.rand(batch_size, config.audio["num_mels"], 30).to(device)
spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
spec_lengths[-1] = spec.size(2)
waveform = torch.rand(batch_size, 1, spec.size(2) * config.audio["hop_length"]).to(device)
return input_dummy, input_lengths, spec, spec_lengths, waveform
return input_dummy, input_lengths, mel, spec, spec_lengths, waveform
def _check_forward_outputs(self, config, output_dict, encoder_config=None, batch_size=2):
self.assertEqual(
@ -139,7 +164,7 @@ class TestVits(unittest.TestCase):
num_speakers = 0
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
config.model_args.spec_segment_size = 10
input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config)
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
model = Vits(config).to(device)
output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
self._check_forward_outputs(config, output_dict)
@ -150,7 +175,7 @@ class TestVits(unittest.TestCase):
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
config.model_args.spec_segment_size = 10
input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config)
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
model = Vits(config).to(device)
@ -171,7 +196,7 @@ class TestVits(unittest.TestCase):
config = VitsConfig(model_args=args)
model = Vits.init_from_config(config, verbose=False).to(device)
model.train()
input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
d_vectors = torch.randn(batch_size, 256).to(device)
output_dict = model.forward(
input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors}
@ -186,7 +211,7 @@ class TestVits(unittest.TestCase):
args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
@ -221,7 +246,7 @@ class TestVits(unittest.TestCase):
config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
config.audio.sample_rate = 16000
input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
@ -330,20 +355,25 @@ class TestVits(unittest.TestCase):
@staticmethod
def _check_parameter_changes(model, model_ref):
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
for item1, item2 in zip(model.named_parameters(), model_ref.named_parameters()):
name = item1[0]
param = item1[1]
param_ref = item2[1]
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
name, param.shape, param, param_ref
)
count += 1
count = count + 1
def _create_batch(self, config, batch_size):
input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(config, batch_size)
input_dummy, input_lengths, mel, spec, mel_lengths, _ = self._create_inputs(config, batch_size)
batch = {}
batch["text_input"] = input_dummy
batch["text_lengths"] = input_lengths
batch["mel_lengths"] = mel_lengths
batch["linear_input"] = mel_spec.transpose(1, 2)
batch["waveform"] = torch.rand(batch_size, config.audio["sample_rate"] * 10, 1).to(device)
batch["tokens"] = input_dummy
batch["token_lens"] = input_lengths
batch["spec_lens"] = mel_lengths
batch["mel_lens"] = mel_lengths
batch["spec"] = spec
batch["mel"] = mel
batch["waveform"] = torch.rand(batch_size, 1, config.audio["sample_rate"] * 10).to(device)
batch["d_vectors"] = None
batch["speaker_ids"] = None
batch["language_ids"] = None
@ -351,33 +381,31 @@ class TestVits(unittest.TestCase):
def test_train_step(self):
# setup the model
config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
model = Vits(config).to(device)
# create a batch
batch = self._create_batch(config, 1)
# model to train
criterions = model.get_criterion()
criterions = [criterions[0].to(device), criterions[1].to(device)]
# reference model to compare model weights
model_ref = Vits(config).to(device)
model.train()
# pass the state to ref model
model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizers = model.get_optimizer()
for _ in range(5):
_, loss_dict = model.train_step(batch, criterions, 0)
loss = loss_dict["loss"]
loss.backward()
optimizers[0].step()
with torch.autograd.set_detect_anomaly(True):
config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
model = Vits(config).to(device)
model.train()
# model to train
optimizers = model.get_optimizer()
criterions = model.get_criterion()
criterions = [criterions[0].to(device), criterions[1].to(device)]
# reference model to compare model weights
model_ref = Vits(config).to(device)
# # pass the state to ref model
model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count = count + 1
for _ in range(5):
batch = self._create_batch(config, 2)
for idx in [0, 1]:
_, loss_dict = model.train_step(batch, criterions, idx)
loss_dict["loss"].backward()
optimizers[idx].step()
optimizers[idx].zero_grad()
_, loss_dict = model.train_step(batch, criterions, 1)
loss = loss_dict["loss"]
loss.backward()
optimizers[1].step()
# check parameter changes
self._check_parameter_changes(model, model_ref)