mirror of https://github.com/coqui-ai/TTS.git

Update WaveGrad
parent fd95926009
commit 3d5205d66f
@@ -58,7 +58,7 @@ class Wavegrad(BaseVocoder):
     # pylint: disable=dangerous-default-value
     def __init__(self, config: Coqpit):
-        super().__init__()
+        super().__init__(config)
         self.config = config
         self.use_weight_norm = config.model_params.use_weight_norm
         self.hop_len = np.prod(config.model_params.upsample_factors)
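A side note on the last context line above: `hop_len` must equal the product of the model's upsample factors, because the stacked upsampling blocks together convert one mel-spectrogram frame into one hop's worth of waveform samples. A quick illustration, using hypothetical factors rather than values read from this diff:

import numpy as np

# Hypothetical factors; config.model_params.upsample_factors supplies the real ones.
upsample_factors = [4, 4, 4, 2, 2]
hop_len = np.prod(upsample_factors)  # 4 * 4 * 4 * 2 * 2 = 256 samples per mel frame
print(hop_len)  # 256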
@@ -258,21 +258,22 @@ class Wavegrad(BaseVocoder):
         return {"model_output": noise_hat}, {"loss": loss}

     def train_log(  # pylint: disable=no-self-use
-        self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
+        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
     ) -> Tuple[Dict, np.ndarray]:
-        return None, None
+        pass

     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
         return self.train_step(batch, criterion)

     def eval_log(  # pylint: disable=no-self-use
-        self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
-    ) -> Tuple[Dict, np.ndarray]:
-        return None, None
+        self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int  # pylint: disable=unused-argument
+    ) -> None:
+        pass

-    def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict):  # pylint: disable=unused-argument
+    def test_run(self, assets: Dict, samples: List[Dict], outputs: Dict):  # pylint: disable=unused-argument
         # setup noise schedule and inference
+        ap = assets["audio_processor"]
         noise_schedule = self.config["test_noise_schedule"]
         betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
         self.compute_noise_level(betas)
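For context on the `test_run` setup above: `test_noise_schedule` defines a linear beta schedule, and `compute_noise_level` (defined elsewhere in this file) derives per-step noise levels from it. A minimal sketch of the standard WaveGrad/DDPM bookkeeping, assuming illustrative schedule values rather than the real config:

import numpy as np

# Illustrative values; the real ones come from config["test_noise_schedule"].
noise_schedule = {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 50}
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])

alphas = 1.0 - betas              # per-step signal retention
alpha_hat = np.cumprod(alphas)    # cumulative product across diffusion steps
noise_level = np.sqrt(alpha_hat)  # sqrt of cumulative alpha: scale of the clean signal at each step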
@@ -307,8 +308,9 @@ class Wavegrad(BaseVocoder):
         return {"input": m, "waveform": y}

     def get_data_loader(
-        self, config: Coqpit, ap: AudioProcessor, is_eval: True, data_items: List, verbose: bool, num_gpus: int
+        self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
     ):
+        ap = assets["audio_processor"]
         dataset = WaveGradDataset(
             ap=ap,
             items=data_items,
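Both `test_run` and `get_data_loader` now fetch the `AudioProcessor` from an `assets` dictionary instead of receiving it as a positional argument. A self-contained sketch of that lookup pattern follows; the stub class is hypothetical, standing in for `TTS.utils.audio.AudioProcessor`:

from typing import Dict

class AudioProcessorStub:
    """Hypothetical stand-in for TTS.utils.audio.AudioProcessor."""
    sample_rate = 22050

# The trainer assembles shared objects into one dict...
training_assets: Dict = {"audio_processor": AudioProcessorStub()}

# ...and each callback looks up what it needs by key, exactly as the
# new test_run and get_data_loader bodies do.
ap = training_assets["audio_processor"]
print(ap.sample_rate)  # 22050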
@@ -1,7 +1,11 @@
 import os

-from TTS.trainer import Trainer, TrainingArgs, init_training
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.configs import WavegradConfig
+from TTS.vocoder.models.wavegrad import Wavegrad
+from TTS.vocoder.datasets.preprocess import load_wav_data

+
 output_path = os.path.dirname(os.path.abspath(__file__))
 config = WavegradConfig(
@@ -22,6 +26,24 @@ config = WavegradConfig(
     data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
     output_path=output_path,
 )
-args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config)
-trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
+
+# init model
+model = Wavegrad(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
 trainer.fit()
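Note how the two files meet: the `training_assets={"audio_processor": ap}` dict passed to `Trainer` in this recipe is the same mapping the model callbacks above receive as `assets` and unpack via `assets["audio_processor"]`, which is what lets the old `ap: AudioProcessor` parameters and the `init_training` helper be removed.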