Use torchaudio for ResNet speaker encoder

pull/1032/head
Eren Gölge 2021-12-13 16:23:57 +00:00
parent 45b2f8e42e
commit 7a987db62b
1 changed file with 19 additions and 21 deletions

View File

@@ -5,12 +5,10 @@ from torch import nn
from TTS.utils.audio import TorchSTFT
from TTS.utils.io import load_fsspec
# import torchaudio
import torchaudio
class PreEmphasis(torch.nn.Module):
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
@@ -114,29 +112,29 @@ class ResNetSpeakerEncoder(nn.Module):
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
TorchSTFT(
n_fft=audio_config["fft_size"],
hop_length=audio_config["hop_length"],
win_length=audio_config["win_length"],
sample_rate=audio_config["sample_rate"],
window="hamming_window",
mel_fmin=0.0,
mel_fmax=None,
use_htk=True,
do_amp_to_db=False,
n_mels=audio_config["num_mels"],
power=2.0,
use_mel=True,
mel_norm=None,
),
"""torchaudio.transforms.MelSpectrogram(
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),""",
)
)
else:
self.torch_spec = None