TTS/tests/test_vocoder_parallel_waveg...

28 lines
767 B
Python
Raw Normal View History

2020-07-17 09:36:36 +00:00
import numpy as np
import torch
2020-09-09 10:27:23 +00:00
from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator
2020-07-17 09:36:36 +00:00
def test_pwgan_generator():
    """Check ParallelWaveganGenerator output shapes for forward and inference.

    Builds a generator with four x4 upsampling layers, feeds it a random
    conditioning tensor of shape (batch, aux_channels, frames) and verifies
    that:

    * ``model(c)`` yields ``(batch, 1, frames * hop_length)``;
    * after ``remove_weight_norm()``, ``model.inference(c)`` yields
      ``(batch, 1, (frames + 4) * hop_length)``.
    """
    batch_size = 2
    aux_channels = 80
    num_frames = 5
    upsample_factors = [4, 4, 4, 4]
    # Derive the total upsampling rate (4 * 4 * 4 * 4 == 256) instead of
    # hard-coding it, so the expected lengths follow the factors above.
    hop_length = int(np.prod(upsample_factors))

    model = ParallelWaveganGenerator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_res_blocks=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=aux_channels,
        dropout=0.0,
        bias=True,
        use_weight_norm=True,
        upsample_factors=upsample_factors)
    dummy_c = torch.rand((batch_size, aux_channels, num_frames))

    # Regular forward pass: one output sample per upsampled input frame.
    output = model(dummy_c)
    expected_shape = (batch_size, 1, num_frames * hop_length)
    # torch.Size is a tuple subclass, so a plain tuple comparison is enough.
    assert tuple(output.shape) == expected_shape, output.shape

    # Inference path, exercised after stripping weight normalisation.
    model.remove_weight_norm()
    output = model.inference(dummy_c)
    # NOTE(review): the 4 extra frames presumably come from context padding
    # applied inside `inference` -- confirm against the model implementation.
    expected_shape = (batch_size, 1, (num_frames + 4) * hop_length)
    assert tuple(output.shape) == expected_shape, output.shape