# Mirror of https://github.com/coqui-ai/TTS.git (34 lines, 852 B, Python)
import random
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from TTS.vocoder.models.wavernn import WaveRNN
|
|
|
|
|
|
def test_wavernn():
    """Check WaveRNN forward-pass and inference output shapes.

    Builds a WaveRNN with an upsampling network and an auxiliary network, then:

    * calls ``model(x, mels)`` on a batch of 2 and expects an output of shape
      ``(2, 1280, 1024)``;
    * calls ``model.inference(mels, True, 5500, 550)`` on a single
      spectrogram with a random number of frames ``y_size`` in ``[20, 60)``
      and expects a 1-D output of ``hop_length * (y_size - 1)`` samples.

    ``y_size`` changes on every run, so the inference assertion reports it
    in its failure message to make a failing case reproducible.
    """
    mode = 10
    hop_length = 256
    feat_dims = 80
    model = WaveRNN(
        rnn_dims=512,
        fc_dims=512,
        mode=mode,
        mulaw=False,
        pad=2,
        use_aux_net=True,
        use_upsample_net=True,
        upsample_factors=[4, 8, 8],  # product is 256, equal to hop_length
        feat_dims=feat_dims,
        compute_dims=128,
        res_out_dims=128,
        num_res_blocks=10,
        hop_length=hop_length,
        sample_rate=22050,
    )

    batch_size = 2
    seq_len = 1280
    dummy_x = torch.rand((batch_size, seq_len))
    # 9 mel frames for 1280 samples; presumably 1280 // hop_length (5) plus
    # 2 * pad (4) context frames -- TODO confirm against WaveRNN.forward.
    dummy_m = torch.rand((batch_size, feat_dims, 9))
    y_size = random.randrange(20, 60)
    dummy_y = torch.rand((feat_dims, y_size))

    output = model(dummy_x, dummy_m)
    # Last dim is 1024 == 2 ** mode; presumably one logit per quantization
    # level when mode is an int -- confirm in WaveRNN.
    expected_train = (batch_size, seq_len, 2**mode)
    assert tuple(output.shape) == expected_train, (
        f"got {tuple(output.shape)}, expected {expected_train}"
    )

    # Positional arguments kept as in the original call; presumably
    # (batched, target, overlap) -- confirm against WaveRNN.inference.
    output = model.inference(dummy_y, True, 5500, 550)
    # tuple() normalizes the shape whether a torch tensor or numpy array is returned.
    expected_infer = (hop_length * (y_size - 1),)
    assert tuple(output.shape) == expected_infer, (
        f"y_size={y_size}: got {tuple(output.shape)}, expected {expected_infer}"
    )
|