2020-05-30 16:09:25 +00:00
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
2020-09-09 10:27:23 +00:00
|
|
|
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator
|
|
|
|
from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator
|
2020-05-30 16:09:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_melgan_discriminator():
    """Check the output shape of a default-constructed ``MelganDiscriminator``.

    A random batch of shape ``(4, 1, 256 * 10)`` is passed through the model;
    the first returned value must have shape ``(4, 1, 10)``. The second
    returned value is not inspected by this test.
    """
    model = MelganDiscriminator()
    print(model)

    dummy_input = torch.rand((4, 1, 256 * 10))
    output, _ = model(dummy_input)

    # A tensor's ``shape`` is a tuple subclass, so it compares directly against
    # a plain tuple. Wrapping the resulting bool in ``np.all`` is unnecessary
    # and hides the operands from pytest's assertion introspection.
    assert output.shape == (4, 1, 10)
|
|
|
|
|
|
|
|
|
|
|
|
def test_melgan_multi_scale_discriminator():
    """Check output structure of a default ``MelganMultiscaleDiscriminator``.

    A random batch of shape ``(4, 1, 256 * 16)`` is passed through the model,
    which returns two sequences, ``scores`` and ``feats``. The test verifies:

    * there are exactly three entries in ``scores`` and as many in ``feats``;
    * the first score has shape ``(4, 1, 64)``;
    * the first three items of ``feats[0]`` have shapes ``(4, 16, 4096)``,
      ``(4, 64, 1024)`` and ``(4, 256, 256)``.
    """
    model = MelganMultiscaleDiscriminator()
    print(model)

    dummy_input = torch.rand((4, 1, 256 * 16))
    scores, feats = model(dummy_input)

    assert len(scores) == 3
    assert len(scores) == len(feats)

    # Shapes are compared directly against tuples: ``shape`` is a tuple
    # subclass, so ``==`` already yields a single bool and pytest can show
    # both operands if the check fails.
    assert scores[0].shape == (4, 1, 64)
    assert feats[0][0].shape == (4, 16, 4096)
    assert feats[0][1].shape == (4, 64, 1024)
    assert feats[0][2].shape == (4, 256, 256)
|