# mirror of https://github.com/coqui-ai/TTS.git
import random

import numpy as np
import torch

from TTS.vocoder.configs import WavernnConfig
from TTS.vocoder.models.wavernn import Wavernn, WavernnArgs


def test_wavernn():
    config = WavernnConfig()
    config.model_args = WavernnArgs(
        rnn_dims=512,
        fc_dims=512,
        mode="mold",
        mulaw=False,
        pad=2,
        use_aux_net=True,
        use_upsample_net=True,
        upsample_factors=[4, 8, 8],  # product must equal audio.hop_length
        feat_dims=80,
        compute_dims=128,
        res_out_dims=128,
        num_res_blocks=10,
    )
    config.audio.hop_length = 256
    config.audio.sample_rate = 2048

    dummy_x = torch.rand((2, 1280))  # waveform batch: (batch, samples)
    dummy_m = torch.rand((2, 80, 9))  # mel batch: (batch, feat_dims, frames)
    y_size = random.randrange(20, 60)
    dummy_y = torch.rand((80, y_size))  # single mel spectrogram for inference

    # mode: mold (mixture of logistic distributions -> 30 output channels)
    model = Wavernn(config)
    output = model(dummy_x, dummy_m)
    assert np.all(output.shape == (2, 1280, 30)), output.shape

    # mode: gauss (single Gaussian -> 2 output channels)
    config.model_args.mode = "gauss"
    model = Wavernn(config)
    output = model(dummy_x, dummy_m)
    assert np.all(output.shape == (2, 1280, 2)), output.shape

    # mode: quantized (4 bits -> 2**4 output classes)
    config.model_args.mode = 4
    model = Wavernn(config)
    output = model(dummy_x, dummy_m)
    assert np.all(output.shape == (2, 1280, 2**4)), output.shape

    # batched inference: positional args are (mels, batched, target, overlap)
    output = model.inference(dummy_y, True, 5500, 550)
    assert np.all(output.shape == (256 * (y_size - 1),))
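

# A minimal convenience entry point, assuming the test is normally collected
# by pytest; running the module directly just executes the same check once.
if __name__ == "__main__":
    test_wavernn()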