Mirror of https://github.com/coqui-ai/TTS.git
28 lines · 767 B · Python
import numpy as np
|
|
import torch
|
|
|
|
from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator
|
|
|
|
|
|
def test_pwgan_generator():
    """Check the output shapes of ParallelWaveganGenerator.

    Builds a generator with weight norm enabled and runs a forward pass on a
    random conditioning tensor. It then checks that the time axis is scaled by
    the product of ``upsample_factors``. Finally it removes weight norm and
    checks the length of the ``inference`` output.
    """
    upsample_factors = [4, 4, 4, 4]
    # Total upsampling ratio applied to each conditioning frame
    # (4 * 4 * 4 * 4 = 256). Derived here so the expected shapes stay correct
    # if the factors are changed.
    hop_length = int(np.prod(upsample_factors))
    # Conditioning input layout as used by the original test: dim 1 must match
    # ``aux_channels``; dim 2 is the time axis that gets upsampled.
    batch_size, aux_channels, num_frames = 2, 80, 5

    model = ParallelWaveganGenerator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_res_blocks=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=aux_channels,
        dropout=0.0,
        bias=True,
        use_weight_norm=True,
        upsample_factors=upsample_factors,
    )
    dummy_c = torch.rand((batch_size, aux_channels, num_frames))

    output = model(dummy_c)
    expected_shape = (batch_size, 1, num_frames * hop_length)
    assert output.shape == expected_shape, output.shape

    model.remove_weight_norm()
    output = model.inference(dummy_c)
    # NOTE(review): ``inference`` yields 4 more frames' worth of samples than
    # the forward pass; presumably it pads the conditioning input (e.g. 2
    # frames per side) before upsampling — confirm against the model code.
    expected_shape = (batch_size, 1, (num_frames + 4) * hop_length)
    assert output.shape == expected_shape, output.shape