diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index d17c84d4..478d9c00 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -7,6 +7,29 @@ from matplotlib import pyplot as plt from TTS.tts.utils.visual import plot_spectrogram +def interpolate_vocoder_input(scale_factor, spec): + """Interpolate spectrogram by the scale factor. + It is mainly used to match the sampling rates of + the tts and vocoder models. + + Args: + scale_factor (float): scale factor to interpolate the spectrogram + spec (np.array): spectrogram to be interpolated + + Returns: + torch.tensor: interpolated spectrogram. + """ + print(" > before interpolation :", spec.shape) + spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) + spec = torch.nn.functional.interpolate(spec, + scale_factor=scale_factor, + recompute_scale_factor=True, + mode='bilinear', + align_corners=False).squeeze(0) + print(" > after interpolation :", spec.shape) + return spec + + def plot_results(y_hat, y, ap, global_step, name_prefix): """ Plot vocoder model results """