TTS/tests/vocoder_tests/test_vocoder_wavernn.py

34 lines
852 B
Python
Raw Normal View History

2021-04-12 09:47:39 +00:00
import random
2020-10-22 08:39:20 +00:00
import numpy as np
import torch
2021-04-12 09:47:39 +00:00
2020-10-22 08:39:20 +00:00
from TTS.vocoder.models.wavernn import WaveRNN
def test_wavernn():
    """Check WaveRNN output shapes for the forward pass and for inference.

    The forward pass is run on a batch of two random inputs and must return
    a tensor of shape ``(2, 1280, 4 * 256)``. Inference is then run on a
    single random input with a randomly chosen number of frames ``y_size``
    and must return a 1-D result of length ``256 * (y_size - 1)``.
    """
    model = WaveRNN(
        rnn_dims=512,
        fc_dims=512,
        mode=10,
        mulaw=False,
        pad=2,
        use_aux_net=True,
        use_upsample_net=True,
        upsample_factors=[4, 8, 8],
        feat_dims=80,
        compute_dims=128,
        res_out_dims=128,
        num_res_blocks=10,
        hop_length=256,
        sample_rate=22050,
    )

    # Inputs are created in the same order as before so the RNG streams
    # (torch and ``random``) are consumed identically.
    dummy_x = torch.rand((2, 1280))
    dummy_m = torch.rand((2, 80, 9))
    y_size = random.randrange(20, 60)
    dummy_y = torch.rand((80, y_size))

    # Forward pass.
    # NOTE(review): 4 * 256 == 1024 presumably corresponds to 2 ** mode
    # output classes for mode=10 -- confirm against WaveRNN.
    output = model(dummy_x, dummy_m)
    assert np.all(output.shape == (2, 1280, 4 * 256)), output.shape

    # Inference.
    # NOTE(review): the positional arguments (True, 5500, 550) look like
    # batched/target/overlap settings -- confirm against WaveRNN.inference.
    output = model.inference(dummy_y, True, 5500, 550)
    expected_shape = (256 * (y_size - 1),)
    # y_size is random, so report it: without it a failure is not reproducible.
    assert np.all(output.shape == expected_shape), (
        f"y_size={y_size}: got shape {output.shape}, expected {expected_shape}"
    )