mirror of https://github.com/coqui-ai/TTS.git
refix linter
parent
7d92b30946
commit
c79a82ed07
|
@ -22,7 +22,6 @@ from torch.utils.data import DataLoader
|
|||
from TTS.config import load_config, register_config
|
||||
from TTS.tts.datasets import load_meta_data
|
||||
from TTS.tts.models import setup_model as setup_tts_model
|
||||
from TTS.vocoder.models.wavegrad import Wavegrad
|
||||
from TTS.tts.utils.text.symbols import parse_symbols
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.callbacks import TrainerCallback
|
||||
|
@ -41,6 +40,7 @@ from TTS.utils.logging import ConsoleLogger, TensorboardLogger
|
|||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
from TTS.vocoder.models import setup_model as setup_vocoder_model
|
||||
from TTS.vocoder.models.wavegrad import Wavegrad
|
||||
|
||||
if platform.system() != "Windows":
|
||||
# https://github.com/pytorch/pytorch/issues/973
|
||||
|
@ -766,14 +766,15 @@ class Trainer:
|
|||
Model must return figures and audios to be logged by the Tensorboard."""
|
||||
if hasattr(self.model, "test_run"):
|
||||
if isinstance(self.model, Wavegrad):
|
||||
return None # TODO: Fix inference on WaveGrad
|
||||
elif hasattr(self.eval_loader.dataset, "load_test_samples"):
|
||||
return None # TODO: Fix inference on WaveGrad
|
||||
if hasattr(self.eval_loader.dataset, "load_test_samples"):
|
||||
samples = self.eval_loader.dataset.load_test_samples(1)
|
||||
figures, audios = self.model.test_run(self.ap, samples, None, self.use_cuda)
|
||||
else:
|
||||
figures, audios = self.model.test_run(self.ap, self.use_cuda)
|
||||
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
|
||||
self.tb_logger.tb_test_figures(self.total_steps_done, figures)
|
||||
return None
|
||||
|
||||
def _fit(self) -> None:
|
||||
"""🏃 train -> evaluate -> test for the number of epochs."""
|
||||
|
|
|
@ -113,7 +113,7 @@ class GlowTTS(BaseTTS):
|
|||
|
||||
@staticmethod
|
||||
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
|
||||
""" Compute and format the mode outputs with the given alignment map"""
|
||||
"""Compute and format the mode outputs with the given alignment map"""
|
||||
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
|
||||
1, 2
|
||||
) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
|
|
|
@ -31,7 +31,7 @@ def setup_model(config: Coqpit):
|
|||
|
||||
|
||||
def setup_generator(c):
|
||||
""" TODO: use config object as arguments"""
|
||||
"""TODO: use config object as arguments"""
|
||||
print(" > Generator Model: {}".format(c.generator_model))
|
||||
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.generator_model))
|
||||
|
@ -94,7 +94,7 @@ def setup_generator(c):
|
|||
|
||||
|
||||
def setup_discriminator(c):
|
||||
""" TODO: use config objekt as arguments"""
|
||||
"""TODO: use config objekt as arguments"""
|
||||
print(" > Discriminator Model: {}".format(c.discriminator_model))
|
||||
if "parallel_wavegan" in c.discriminator_model:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
|
||||
|
|
|
@ -261,7 +261,9 @@ class Wavegrad(BaseModel):
|
|||
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
||||
return None, None
|
||||
|
||||
def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict, use_cuda): # pylint: disable=unused-argument
|
||||
def test_run(
|
||||
self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict, use_cuda
|
||||
): # pylint: disable=unused-argument
|
||||
# setup noise schedule and inference
|
||||
noise_schedule = self.config["test_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
|
|
|
@ -571,7 +571,7 @@ class Wavernn(BaseVocoder):
|
|||
|
||||
@torch.no_grad()
|
||||
def test_run(
|
||||
self, ap: AudioProcessor, samples: List[Dict], output: Dict, use_cuda # pylint: disable=unused-argument
|
||||
self, ap: AudioProcessor, samples: List[Dict], output: Dict, use_cuda # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, Dict]:
|
||||
figures = {}
|
||||
audios = {}
|
||||
|
|
|
@ -6,6 +6,7 @@ from tests import get_device_id, get_tests_output_path, run_cli
|
|||
from TTS.config.shared_configs import BaseAudioConfig
|
||||
from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig
|
||||
|
||||
|
||||
def run_test_train():
|
||||
command = (
|
||||
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
|
||||
|
@ -17,6 +18,7 @@ def run_test_train():
|
|||
)
|
||||
run_cli(command)
|
||||
|
||||
|
||||
config_path = os.path.join(get_tests_output_path(), "test_speaker_encoder_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
||||
|
|
Loading…
Reference in New Issue