diff --git a/tests/test_glow_tts.py b/tests/test_glow_tts.py index 8e699faf..e71c167a 100644 --- a/tests/test_glow_tts.py +++ b/tests/test_glow_tts.py @@ -129,3 +129,58 @@ class GlowTTSTrainTest(unittest.TestCase): count, param.shape, param, param_ref ) count += 1 + +class GlowTTSInferenceTest(unittest.TestCase): + @staticmethod + def test_inference(): + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) + input_lengths = torch.randint(100, 129, (8,)).long().to(device) + input_lengths[-1] = 128 + mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device) + mel_lengths = torch.randint(20, 30, (8,)).long().to(device) + speaker_ids = torch.randint(0, 5, (8,)).long().to(device) + + # create model + model = GlowTTS( + num_chars=32, + hidden_channels_enc=48, + hidden_channels_dec=48, + hidden_channels_dp=32, + out_channels=80, + encoder_type="rel_pos_transformer", + encoder_params={ + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 16, # 4 times the hidden_channels + "input_length": None, + }, + use_encoder_prenet=True, + num_flow_blocks_dec=12, + kernel_size_dec=5, + dilation_rate=1, + num_block_layers=4, + dropout_p_dec=0.0, + num_speakers=0, + c_in_channels=0, + num_splits=4, + num_squeeze=1, + sigmoid_scale=False, + mean_only=False, + ).to(device) + + model.eval() + print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) + + # inference encoder and decoder with MAS + y, _, _, _, _, _, _ = model.inference_with_MAS( + input_dummy, input_lengths, mel_spec, mel_lengths, None + ) + + y_dec, _ = model.decoder_inference(mel_spec, mel_lengths + ) + + assert (y_dec.shape == y.shape), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format( + y.shape, y_dec.shape + )