mirror of https://github.com/coqui-ai/TTS.git
Remove torchaudio requeriment
parent
2e516869a1
commit
d39200e69b
|
@ -1,7 +1,10 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
import torch.nn as nn
|
||||
from torch import nn
|
||||
|
||||
# import torchaudio
|
||||
|
||||
from TTS.utils.audio import TorchSTFT
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
@ -110,14 +113,29 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
if self.use_torch_spec:
|
||||
self.torch_spec = torch.nn.Sequential(
|
||||
PreEmphasis(audio_config["preemphasis"]),
|
||||
torchaudio.transforms.MelSpectrogram(
|
||||
TorchSTFT(
|
||||
n_fft=audio_config["fft_size"],
|
||||
hop_length=audio_config["hop_length"],
|
||||
win_length=audio_config["win_length"],
|
||||
sample_rate=audio_config["sample_rate"],
|
||||
window="hamming_window",
|
||||
mel_fmin=0.0,
|
||||
mel_fmax=None,
|
||||
use_htk=True,
|
||||
do_amp_to_db=False,
|
||||
n_mels=audio_config["num_mels"],
|
||||
power=2.0,
|
||||
use_mel=True,
|
||||
mel_norm=None
|
||||
),
|
||||
'''torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=audio_config["sample_rate"],
|
||||
n_fft=audio_config["fft_size"],
|
||||
win_length=audio_config["win_length"],
|
||||
hop_length=audio_config["hop_length"],
|
||||
window_fn=torch.hamming_window,
|
||||
n_mels=audio_config["num_mels"],
|
||||
),
|
||||
),'''
|
||||
)
|
||||
else:
|
||||
self.torch_spec = None
|
||||
|
|
|
@ -4,7 +4,7 @@ from itertools import chain
|
|||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
# import torchaudio
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
|
@ -395,7 +395,7 @@ class Vits(BaseTTS):
|
|||
if config.use_speaker_encoder_as_loss:
|
||||
if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
|
||||
raise RuntimeError(
|
||||
" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
|
||||
" [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
|
||||
)
|
||||
self.speaker_manager.init_speaker_encoder(
|
||||
config.speaker_encoder_model_path, config.speaker_encoder_config_path
|
||||
|
@ -410,14 +410,17 @@ class Vits(BaseTTS):
|
|||
hasattr(self.speaker_encoder, "audio_config")
|
||||
and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
|
||||
):
|
||||
self.audio_transform = torchaudio.transforms.Resample(
|
||||
raise RuntimeError(
|
||||
" [!] To use the speaker consistency loss (SCL) you need to have the TTS model sampling rate ({}) equal to the speaker encoder sampling rate ({}) !".format(self.audio_config["sample_rate"], self.speaker_encoder.audio_config["sample_rate"])
|
||||
)
|
||||
'''self.audio_transform = torchaudio.transforms.Resample(
|
||||
orig_freq=self.audio_config["sample_rate"],
|
||||
new_freq=self.speaker_encoder.audio_config["sample_rate"],
|
||||
)
|
||||
else:
|
||||
self.audio_transform = None
|
||||
self.audio_transform = None'''
|
||||
else:
|
||||
self.audio_transform = None
|
||||
# self.audio_transform = None
|
||||
self.speaker_encoder = None
|
||||
|
||||
def _init_speaker_embedding(self, config):
|
||||
|
@ -655,8 +658,8 @@ class Vits(BaseTTS):
|
|||
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
|
||||
|
||||
# resample audio to speaker encoder sample_rate
|
||||
if self.audio_transform is not None:
|
||||
wavs_batch = self.audio_transform(wavs_batch)
|
||||
'''if self.audio_transform is not None:
|
||||
wavs_batch = self.audio_transform(wavs_batch)'''
|
||||
|
||||
pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
|
||||
|
||||
|
|
|
@ -32,6 +32,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
|||
use_mel=False,
|
||||
do_amp_to_db=False,
|
||||
spec_gain=1.0,
|
||||
power=None,
|
||||
use_htk=False,
|
||||
mel_norm="slaney"
|
||||
):
|
||||
super().__init__()
|
||||
self.n_fft = n_fft
|
||||
|
@ -45,6 +48,9 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
|||
self.use_mel = use_mel
|
||||
self.do_amp_to_db = do_amp_to_db
|
||||
self.spec_gain = spec_gain
|
||||
self.power = power
|
||||
self.use_htk = use_htk
|
||||
self.mel_norm = mel_norm
|
||||
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
|
||||
self.mel_basis = None
|
||||
if use_mel:
|
||||
|
@ -83,6 +89,10 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
|||
M = o[:, :, :, 0]
|
||||
P = o[:, :, :, 1]
|
||||
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
|
||||
|
||||
if self.power is not None:
|
||||
S = S ** self.power
|
||||
|
||||
if self.use_mel:
|
||||
S = torch.matmul(self.mel_basis.to(x), S)
|
||||
if self.do_amp_to_db:
|
||||
|
@ -91,7 +101,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
|
|||
|
||||
def _build_mel_basis(self):
|
||||
mel_basis = librosa.filters.mel(
|
||||
self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
|
||||
self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax, htk=self.use_htk, norm=self.mel_norm
|
||||
)
|
||||
self.mel_basis = torch.from_numpy(mel_basis).float()
|
||||
|
||||
|
|
|
@ -26,5 +26,3 @@ unidic-lite==1.0.8
|
|||
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
|
||||
fsspec>=2021.04.0
|
||||
pyworld
|
||||
webrtcvad
|
||||
torchaudio>=0.7
|
||||
|
|
Loading…
Reference in New Issue