mirror of https://github.com/coqui-ai/TTS.git
Update VITS tests
parent 8b3ba02c95
commit c0b40a0cb7
tests/tts_tests/test_vits.py

@@ -3,17 +3,19 @@ import os
 import unittest
 
 import torch
+from TTS.tts.datasets.formatters import ljspeech
 
 from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
 from TTS.config import load_config
 from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
 from TTS.tts.configs.vits_config import VitsConfig
-from TTS.tts.models.vits import Vits, VitsArgs
+from TTS.tts.models.vits import Vits, VitsArgs, load_audio, amp_to_db, db_to_amp, wav_to_spec, wav_to_mel, spec_to_mel, VitsDataset
 from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.logging.tensorboard_logger import TensorboardLogger
+from trainer.logging.tensorboard_logger import TensorboardLogger
 
 LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
 SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
+WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
 
 
 torch.manual_seed(1)
@@ -23,6 +25,28 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # pylint: disable=no-self-use
 class TestVits(unittest.TestCase):
+    def test_load_audio(self):
+        wav, sr = load_audio(WAV_FILE)
+        self.assertEqual(wav.shape, (1, 41885))
+        self.assertEqual(sr, 22050)
+
+        spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False)
+        mel = wav_to_mel(wav, n_fft=1024, num_mels=80, sample_rate=sr, hop_length=512, win_length=1024, fmin=0, fmax=8000, center=False)
+        mel2 = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000)
+
+        self.assertEqual((mel - mel2).abs().max(), 0)
+        self.assertEqual(spec.shape[0], mel.shape[0])
+        self.assertEqual(spec.shape[2], mel.shape[2])
+
+        spec_db = amp_to_db(spec)
+        spec_amp = db_to_amp(spec_db)
+
+        self.assertAlmostEqual((spec - spec_amp).abs().max(), 0, delta=1e-4)
+
+    def test_dataset(self):
+        """TODO:"""
+        ...
+
     def test_init_multispeaker(self):
         num_speakers = 10
         args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
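Note on the new test_load_audio: wav_to_mel on the raw waveform and spec_to_mel on the wav_to_spec output both reduce to projecting the same STFT magnitudes onto a mel filterbank, which is why the test can assert exact equality, while the amp_to_db/db_to_amp round trip is only exact above the clamp floor (hence delta=1e-4). A self-contained sketch of both properties in plain torch/librosa; the log-compression below is an assumed stand-in for the library's amp_to_db/db_to_amp, not the TTS code itself:

import librosa
import torch

sr, n_fft, hop, win, n_mels = 22050, 1024, 512, 1024, 80
wav = torch.randn(1, sr)  # one second of noise stands in for example_1.wav

# linear magnitude spectrogram: (1, n_fft // 2 + 1, frames)
window = torch.hann_window(win)
spec = torch.stft(
    wav, n_fft, hop_length=hop, win_length=win, window=window, center=False, return_complex=True
).abs()

# mel filterbank (n_mels, n_fft // 2 + 1); both the wav->mel and spec->mel
# paths end in this same matrix product, so their outputs match exactly
mel_basis = torch.from_numpy(
    librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=0, fmax=8000)
).float()
mel = torch.matmul(mel_basis, spec)  # (1, n_mels, frames)

# log-compression round trip: exact only above the clamp floor of 1e-5,
# which is what the delta=1e-4 tolerance in the test allows for
spec_db = torch.log(torch.clamp(spec, min=1e-5))
spec_amp = torch.exp(spec_db)
assert (spec - spec_amp).abs().max() < 1e-4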
@@ -107,10 +131,11 @@ class TestVits(unittest.TestCase):
         input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
         input_lengths[-1] = 128
         spec = torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, 30).to(device)
+        mel = torch.rand(batch_size, config.audio["num_mels"], 30).to(device)
         spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
         spec_lengths[-1] = spec.size(2)
         waveform = torch.rand(batch_size, 1, spec.size(2) * config.audio["hop_length"]).to(device)
-        return input_dummy, input_lengths, spec, spec_lengths, waveform
+        return input_dummy, input_lengths, mel, spec, spec_lengths, waveform
 
     def _check_forward_outputs(self, config, output_dict, encoder_config=None, batch_size=2):
         self.assertEqual(
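Note on _create_inputs: pinning input_lengths[-1] = 128 and spec_lengths[-1] = spec.size(2) guarantees that at least one sample spans the full padded axis, so masks derived from the lengths agree with the actual tensor dimensions; this reading is inferred from the test, not stated in it. A minimal sketch of that invariant in plain torch (the mask is written inline rather than imported from TTS):

import torch

batch_size, max_len = 8, 128
lengths = torch.randint(100, 129, (batch_size,)).long()
lengths[-1] = max_len  # make one sample use the whole padded axis

# boolean mask (batch_size, max_len): True where positions are valid
mask = torch.arange(max_len)[None, :] < lengths[:, None]

# without the pin, lengths.max() could be smaller than the padded size, and
# code that infers the time axis from lengths.max() would disagree with the
# tensors it masks
assert int(lengths.max()) == max_len == mask.shape[1]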
@@ -139,7 +164,7 @@ class TestVits(unittest.TestCase):
         num_speakers = 0
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config)
+        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
         model = Vits(config).to(device)
         output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
         self._check_forward_outputs(config, output_dict)
@@ -150,7 +175,7 @@ class TestVits(unittest.TestCase):
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
 
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config)
+        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
         speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
 
         model = Vits(config).to(device)
@@ -171,7 +196,7 @@ class TestVits(unittest.TestCase):
         config = VitsConfig(model_args=args)
         model = Vits.init_from_config(config, verbose=False).to(device)
         model.train()
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         d_vectors = torch.randn(batch_size, 256).to(device)
         output_dict = model.forward(
             input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors}
@@ -186,7 +211,7 @@ class TestVits(unittest.TestCase):
         args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
 
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
         lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
 
@@ -221,7 +246,7 @@ class TestVits(unittest.TestCase):
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
         config.audio.sample_rate = 16000
 
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
         lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
 
@@ -330,20 +355,25 @@ class TestVits(unittest.TestCase):
     @staticmethod
     def _check_parameter_changes(model, model_ref):
         count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+        for item1, item2 in zip(model.named_parameters(), model_ref.named_parameters()):
+            name = item1[0]
+            param = item1[1]
+            param_ref = item2[1]
             assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref
+                name, param.shape, param, param_ref
             )
-            count += 1
+            count = count + 1
 
     def _create_batch(self, config, batch_size):
-        input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(config, batch_size)
+        input_dummy, input_lengths, mel, spec, mel_lengths, _ = self._create_inputs(config, batch_size)
         batch = {}
-        batch["text_input"] = input_dummy
-        batch["text_lengths"] = input_lengths
-        batch["mel_lengths"] = mel_lengths
-        batch["linear_input"] = mel_spec.transpose(1, 2)
-        batch["waveform"] = torch.rand(batch_size, config.audio["sample_rate"] * 10, 1).to(device)
+        batch["tokens"] = input_dummy
+        batch["token_lens"] = input_lengths
+        batch["spec_lens"] = mel_lengths
+        batch["mel_lens"] = mel_lengths
+        batch["spec"] = spec
+        batch["mel"] = mel
+        batch["waveform"] = torch.rand(batch_size, 1, config.audio["sample_rate"] * 10).to(device)
         batch["d_vectors"] = None
         batch["speaker_ids"] = None
         batch["language_ids"] = None
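Note on _check_parameter_changes: iterating named_parameters() instead of parameters() lets the assertion report which tensor failed to update rather than a bare counter. A standalone sketch of the same snapshot-and-compare technique on a toy module (not the VITS classes):

import copy

import torch

model = torch.nn.Linear(4, 2)
model_ref = copy.deepcopy(model)  # frozen snapshot of the initial weights

# one optimization step; every parameter the loss touches must change
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(3, 4)).sum()
loss.backward()
optimizer.step()

for (name, param), (_, param_ref) in zip(model.named_parameters(), model_ref.named_parameters()):
    assert (param != param_ref).any(), f"param {name} with shape {param.shape} not updated!"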
@@ -351,33 +381,31 @@ class TestVits(unittest.TestCase):
 
     def test_train_step(self):
         # setup the model
-        config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
-        model = Vits(config).to(device)
-        # create a batch
-        batch = self._create_batch(config, 1)
-        # model to train
-        criterions = model.get_criterion()
-        criterions = [criterions[0].to(device), criterions[1].to(device)]
-        # reference model to compare model weights
-        model_ref = Vits(config).to(device)
-        model.train()
-        # pass the state to ref model
-        model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            assert (param - param_ref).sum() == 0, param
-            count += 1
-        optimizers = model.get_optimizer()
-        for _ in range(5):
-            _, loss_dict = model.train_step(batch, criterions, 0)
-            loss = loss_dict["loss"]
-            loss.backward()
-            optimizers[0].step()
-
-            _, loss_dict = model.train_step(batch, criterions, 1)
-            loss = loss_dict["loss"]
-            loss.backward()
-            optimizers[1].step()
+        with torch.autograd.set_detect_anomaly(True):
+
+            config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
+            model = Vits(config).to(device)
+            model.train()
+            # model to train
+            optimizers = model.get_optimizer()
+            criterions = model.get_criterion()
+            criterions = [criterions[0].to(device), criterions[1].to(device)]
+            # reference model to compare model weights
+            model_ref = Vits(config).to(device)
+            # # pass the state to ref model
+            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
+            count = 0
+            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+                assert (param - param_ref).sum() == 0, param
+                count = count + 1
+            for _ in range(5):
+                batch = self._create_batch(config, 2)
+                for idx in [0, 1]:
+                    _, loss_dict = model.train_step(batch, criterions, idx)
+                    loss_dict["loss"].backward()
+                    optimizers[idx].step()
+                    optimizers[idx].zero_grad()
+
         # check parameter changes
         self._check_parameter_changes(model, model_ref)
 
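Note on the rewritten test_train_step: both optimizers returned by get_optimizer() now run through one inner loop over the criterion index, and the added zero_grad() stops gradients accumulating across iterations, which the old version never cleared. A minimal sketch of that alternating two-optimizer schedule on a toy generator/discriminator pair; which index drives which network in VITS is not shown in this diff, so the assignment below is arbitrary:

import torch

gen = torch.nn.Linear(8, 8)
disc = torch.nn.Linear(8, 1)
optimizers = [
    torch.optim.Adam(disc.parameters(), lr=1e-3),  # idx 0 (arbitrary here)
    torch.optim.Adam(gen.parameters(), lr=1e-3),   # idx 1
]

for _ in range(5):
    batch = torch.randn(2, 8)
    for idx in [0, 1]:
        # toy losses, only there to produce gradients; detach() keeps the
        # idx-0 step from backpropagating into the generator
        if idx == 0:
            loss = disc(gen(batch).detach()).pow(2).mean()
        else:
            loss = (1.0 - disc(gen(batch))).pow(2).mean()
        optimizers[idx].zero_grad()  # clear stale grads before this step
        loss.backward()
        optimizers[idx].step()

Zeroing before backward is equivalent to the test's step-then-zero order for this purpose, and it also discards the discriminator gradients produced as a side effect of the generator step.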