diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index 7bd507fb..3b96f270 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -5,12 +5,10 @@ from torch import nn from TTS.utils.audio import TorchSTFT from TTS.utils.io import load_fsspec -# import torchaudio +import torchaudio - - -class PreEmphasis(torch.nn.Module): +class PreEmphasis(nn.Module): def __init__(self, coefficient=0.97): super().__init__() self.coefficient = coefficient @@ -114,29 +112,29 @@ class ResNetSpeakerEncoder(nn.Module): if self.use_torch_spec: self.torch_spec = torch.nn.Sequential( PreEmphasis(audio_config["preemphasis"]), - TorchSTFT( - n_fft=audio_config["fft_size"], - hop_length=audio_config["hop_length"], - win_length=audio_config["win_length"], - sample_rate=audio_config["sample_rate"], - window="hamming_window", - mel_fmin=0.0, - mel_fmax=None, - use_htk=True, - do_amp_to_db=False, - n_mels=audio_config["num_mels"], - power=2.0, - use_mel=True, - mel_norm=None, - ), - """torchaudio.transforms.MelSpectrogram( + # TorchSTFT( + # n_fft=audio_config["fft_size"], + # hop_length=audio_config["hop_length"], + # win_length=audio_config["win_length"], + # sample_rate=audio_config["sample_rate"], + # window="hamming_window", + # mel_fmin=0.0, + # mel_fmax=None, + # use_htk=True, + # do_amp_to_db=False, + # n_mels=audio_config["num_mels"], + # power=2.0, + # use_mel=True, + # mel_norm=None, + # ) + torchaudio.transforms.MelSpectrogram( sample_rate=audio_config["sample_rate"], n_fft=audio_config["fft_size"], win_length=audio_config["win_length"], hop_length=audio_config["hop_length"], window_fn=torch.hamming_window, n_mels=audio_config["num_mels"], - ),""", + ) ) else: self.torch_spec = None