mirror of https://github.com/coqui-ai/TTS.git
add tf tacotron2 test and edit test utils imports after utils
refactoringpull/10/head
parent
67397be1c0
commit
8805370645
|
@ -6,7 +6,8 @@ import torch as T
|
|||
from TTS.server.synthesizer import Synthesizer
|
||||
from TTS.tests import get_tests_input_path, get_tests_output_path
|
||||
from TTS.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model
|
||||
from TTS.utils.generic_utils import setup_model
|
||||
from TTS.utils.io import load_config, save_checkpoint
|
||||
|
||||
|
||||
class DemoServerTest(unittest.TestCase):
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import numpy as np
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from TTS.utils.generic_utils import load_config
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.datasets import TTSDataset
|
||||
from TTS.datasets.preprocess import ljspeech
|
||||
|
|
|
@ -6,7 +6,7 @@ import numpy as np
|
|||
|
||||
from torch import optim
|
||||
from torch import nn
|
||||
from TTS.utils.generic_utils import load_config
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.layers.losses import MSELossMasked
|
||||
from TTS.models.tacotron2 import Tacotron2
|
||||
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
import os
|
||||
import copy
|
||||
import torch
|
||||
import unittest
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from torch import optim
|
||||
from torch import nn
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.layers.losses import MSELossMasked
|
||||
from TTS.tf.models.tacotron2 import Tacotron2
|
||||
|
||||
#pylint: disable=unused-variable
|
||||
|
||||
torch.manual_seed(1)
|
||||
use_cuda = torch.cuda.is_available()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
file_path = os.path.dirname(os.path.realpath(__file__))
|
||||
c = load_config(os.path.join(file_path, 'test_config.json'))
|
||||
|
||||
|
||||
class TacotronTFTrainTest(unittest.TestCase):
|
||||
def test_train_step(self):
|
||||
''' test forward pass '''
|
||||
input = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||
input_lengths = torch.randint(100, 128, (8, )).long().to(device)
|
||||
input_lengths = torch.sort(input_lengths, descending=True)[0]
|
||||
mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
|
||||
mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
|
||||
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
|
||||
stop_targets = torch.zeros(8, 30, 1).float().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
||||
|
||||
input = tf.convert_to_tensor(input.cpu().numpy())
|
||||
input_lengths = tf.convert_to_tensor(input_lengths.cpu().numpy())
|
||||
mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy())
|
||||
|
||||
for idx in mel_lengths:
|
||||
stop_targets[:, int(idx.item()):, 0] = 1.0
|
||||
|
||||
stop_targets = stop_targets.view(input.shape[0],
|
||||
stop_targets.size(1) // c.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
||||
|
||||
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5)
|
||||
# training pass
|
||||
output = model(input, input_lengths, mel_spec, training=True)
|
||||
|
||||
# check model output shapes
|
||||
assert np.all(output[0].shape == mel_spec.shape)
|
||||
assert np.all(output[1].shape == mel_spec.shape)
|
||||
assert output[2].shape[2] == input.shape[1]
|
||||
assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r)
|
||||
assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r)
|
||||
|
||||
# inference pass
|
||||
output = model(input, training=False)
|
|
@ -5,7 +5,7 @@ import unittest
|
|||
|
||||
from torch import optim
|
||||
from torch import nn
|
||||
from TTS.utils.generic_utils import load_config
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.layers.losses import L1LossMasked
|
||||
from TTS.models.tacotron import Tacotron
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import os
|
|||
import unittest
|
||||
from TTS.utils.text import *
|
||||
from TTS.tests import get_tests_path
|
||||
from TTS.utils.generic_utils import load_config
|
||||
from TTS.utils.io import load_config
|
||||
|
||||
TESTS_PATH = get_tests_path()
|
||||
conf = load_config(os.path.join(TESTS_PATH, 'test_config.json'))
|
||||
|
@ -92,4 +92,4 @@ def test_text2phone():
|
|||
gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"
|
||||
lang = "en-us"
|
||||
ph = text2phone(text, lang)
|
||||
assert gt == ph, f"\n{phonemes} \n vs \n{gt}"
|
||||
assert gt == ph, f"\n{phonemes} \n vs \n{gt}"
|
||||
|
|
Loading…
Reference in New Issue