From 424d04e4f6dfca0cb34c7957a970734304cdfb78 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sun, 20 Feb 2022 12:37:27 +0100
Subject: [PATCH] Make style

---
 TTS/bin/distribute.py                       |   1 -
 TTS/bin/train_encoder.py                    |   1 -
 TTS/bin/train_tts.py                        |   2 +-
 TTS/bin/train_vocoder.py                    |   2 +-
 TTS/model.py                                |  16 +--
 TTS/speaker_encoder/utils/training.py       |   9 +-
 TTS/tts/datasets/dataset.py                 |   2 -
 TTS/tts/layers/losses.py                    |   4 +-
 TTS/tts/models/base_tts.py                  |   8 +-
 TTS/tts/models/glow_tts.py                  |   3 +-
 TTS/tts/models/vits.py                      | 121 ++++++++----------
 TTS/tts/utils/helpers.py                    |   1 +
 TTS/tts/utils/text/characters.py            |   1 -
 TTS/tts/utils/text/punctuation.py           |   2 +-
 TTS/vocoder/models/wavegrad.py              |   6 +-
 .../multilingual/vits_tts/train_vits_tts.py |   4 +-
 tests/tts_tests/test_glow_tts.py            |   2 +-
 tests/tts_tests/test_vits.py                |  29 +++--
 18 files changed, 103 insertions(+), 111 deletions(-)

diff --git a/TTS/bin/distribute.py b/TTS/bin/distribute.py
index 40f60d5d..97e2f0e3 100644
--- a/TTS/bin/distribute.py
+++ b/TTS/bin/distribute.py
@@ -7,7 +7,6 @@ import subprocess
 import time
 
 import torch
-
 from trainer import TrainerArgs
 
 
diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index f19966ee..5828411c 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -8,7 +8,6 @@ import traceback
 
 import torch
 from torch.utils.data import DataLoader
-
 from trainer.torch import NoamLR
 
 from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index 467685b2..31813712 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -1,5 +1,5 @@
-from dataclasses import dataclass, field
 import os
+from dataclasses import dataclass, field
 
 from trainer import Trainer, TrainerArgs
 
diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py
index c52fd962..32ecd7bd 100644
--- a/TTS/bin/train_vocoder.py
+++ b/TTS/bin/train_vocoder.py
@@ -1,5 +1,5 @@
-from dataclasses import dataclass, field
 import os
+from dataclasses import dataclass, field
 
 from trainer import Trainer, TrainerArgs
 
diff --git a/TTS/model.py b/TTS/model.py
index d7bd4f9f..39cbeabc 100644
--- a/TTS/model.py
+++ b/TTS/model.py
@@ -5,11 +5,11 @@
 import torch
 from coqpit import Coqpit
 from torch import nn
 
+# pylint: skip-file
 
 class BaseTrainerModel(ABC, nn.Module):
-    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
-    """
+    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""
 
     @staticmethod
     @abstractmethod
@@ -63,7 +63,7 @@ class BaseTrainerModel(ABC, nn.Module):
         """
         return batch
 
-    def format_batch_on_device(self, batch:Dict) -> Dict:
+    def format_batch_on_device(self, batch: Dict) -> Dict:
         """Format batch on device before sending it to the model.
 
         If not implemented, model uses the batch as is.
@@ -124,7 +124,6 @@
         """The same as `train_log()`"""
         ...
 
-
     @abstractmethod
     def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
         """Load a checkpoint and get ready for training or inference.
@@ -148,13 +147,8 @@
 
     @abstractmethod
     def get_data_loader(
-        self,
-        config: Coqpit,
-        assets: Dict,
-        is_eval: True,
-        data_items: List,
-        verbose: bool,
-        num_gpus: int):
+        self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
+    ):
         ...
 
     # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/speaker_encoder/utils/training.py
index 5c2de274..c64c46b7 100644
--- a/TTS/speaker_encoder/utils/training.py
+++ b/TTS/speaker_encoder/utils/training.py
@@ -1,16 +1,15 @@
-from asyncio.log import logger
-from dataclasses import dataclass, field
 import os
+from dataclasses import dataclass, field
 
 from coqpit import Coqpit
+from trainer import TrainerArgs
+from trainer.logging import logger_factory
+from trainer.logging.console_logger import ConsoleLogger
 
 from TTS.config import load_config, register_config
-from trainer import TrainerArgs
 from TTS.tts.utils.text.characters import parse_symbols
 from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
 from TTS.utils.io import copy_model_files
-from trainer.logging import logger_factory
-from trainer.logging.console_logger import ConsoleLogger
 from TTS.utils.trainer_utils import get_last_checkpoint
 
 
diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index 865209c2..d4d1a7e5 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -769,5 +769,3 @@ class F0Dataset:
         print("\n")
         print(f"{indent}> F0Dataset ")
         print(f"{indent}| > Number of instances : {len(self.samples)}")
-
-
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 57d36717..e03cf084 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -672,7 +672,9 @@ class VitsDiscriminatorLoss(nn.Module):
     def forward(self, scores_disc_real, scores_disc_fake):
         loss = 0.0
         return_dict = {}
-        loss_disc, loss_disc_real, _ = self.discriminator_loss(scores_real=scores_disc_real, scores_fake=scores_disc_fake)
+        loss_disc, loss_disc_real, _ = self.discriminator_loss(
+            scores_real=scores_disc_real, scores_fake=scores_disc_fake
+        )
         return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
         loss = loss + return_dict["loss_disc"]
         return_dict["loss"] = loss
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 6dd7ca72..dd6539a5 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -26,8 +26,12 @@ class BaseTTS(BaseTrainerModel):
     """
 
     def __init__(
-        self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None
+        self,
+        config: Coqpit,
+        ap: "AudioProcessor",
+        tokenizer: "TTSTokenizer",
+        speaker_manager: SpeakerManager = None,
+        language_manager: LanguageManager = None,
     ):
         super().__init__()
         self.config = config
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index 23eb48da..c30f043a 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -530,7 +530,8 @@ class GlowTTS(BaseTTS):
         self.store_inverse()
         assert not self.training
 
-    def get_criterion(self):
+    @staticmethod
+    def get_criterion():
         from TTS.tts.layers.losses import GlowTTSLoss  # pylint: disable=import-outside-toplevel
 
         return GlowTTSLoss()
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index b7766e92..ec6c9e5b 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1,9 +1,8 @@
-import collections
 import math
 import os
 from dataclasses import dataclass, field, replace
 from itertools import chain
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -25,7 +24,7 @@ from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDuration
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
 from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -38,6 +37,7 @@ from TTS.vocoder.utils.generic_utils import plot_results
 # IO / Feature extraction
 ##############################
 
+# pylint: disable=global-statement
 hann_window = {}
 mel_basis = {}
 
@@ -200,7 +200,7 @@ class VitsDataset(TTSDataset):
         text, wav_file, speaker_name, language_name, _ = _parse_sample(item)
         raw_text = text
 
-        wav, sr = load_audio(wav_file)
+        wav, _ = load_audio(wav_file)
         wav_filename = os.path.basename(wav_file)
 
         token_ids = self.get_token_ids(idx, text)
@@ -538,12 +538,14 @@ class Vits(BaseTTS):
         >>> model = Vits(config)
     """
 
-    def __init__(self,
+    def __init__(
+        self,
         config: Coqpit,
         ap: "AudioProcessor" = None,
         tokenizer: "TTSTokenizer" = None,
         speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None,):
+        language_manager: LanguageManager = None,
+    ):
 
         super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
 
@@ -673,9 +675,9 @@ class Vits(BaseTTS):
             )
         # pylint: disable=W0101,W0105
         self.audio_transform = torchaudio.transforms.Resample(
-                orig_freq=self.config.audio.sample_rate,
-                new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
-            )
+            orig_freq=self.config.audio.sample_rate,
+            new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
+        )
 
     def _init_speaker_embedding(self):
         # pylint: disable=attribute-defined-outside-init
@@ -777,9 +779,9 @@ class Vits(BaseTTS):
         with torch.no_grad():
             o_scale = torch.exp(-2 * logs_p)
             logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
+            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
             logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
-            logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
             logp = logp2 + logp3 + logp1 + logp4
 
             attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()  # [b, 1, t, t']
@@ -806,7 +808,7 @@
         outputs["loss_duration"] = loss_duration
         return outputs, attn
 
-    def forward(
+    def forward(  # pylint: disable=dangerous-default-value
         self,
         x: torch.tensor,
         x_lengths: torch.tensor,
@@ -886,7 +888,7 @@
                 waveform,
                 slice_ids * self.config.audio.hop_length,
                 self.args.spec_segment_size * self.config.audio.hop_length,
-                pad_short = True
+                pad_short=True,
             )
 
             if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
@@ -929,7 +931,9 @@
             return aux_input["x_lengths"]
         return torch.tensor(x.shape[1:2]).to(x.device)
 
-    def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}):
+    def inference(
+        self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
+    ):  # pylint: disable=dangerous-default-value
         """
         Note:
             To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1.
@@ -1023,7 +1027,6 @@
         o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
         return o_hat, y_mask, (z, z_p, z_hat)
 
-
     def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Perform a single training step. Run the model forward pass and compute losses.
 
@@ -1062,7 +1065,7 @@
             )
 
             # cache tensors for the generator pass
-            self.model_outputs_cache = outputs
+            self.model_outputs_cache = outputs  # pylint: disable=attribute-defined-outside-init
 
             # compute scores and features
             scores_disc_fake, _, scores_disc_real, _ = self.disc(
@@ -1082,14 +1085,16 @@
         # compute melspec segment
         with autocast(enabled=False):
-            mel_slice = segment(mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True)
+            mel_slice = segment(
+                mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True
+            )
             mel_slice_hat = wav_to_mel(
-                y = self.model_outputs_cache["model_outputs"].float(),
-                n_fft = self.config.audio.fft_size,
-                sample_rate = self.config.audio.sample_rate,
-                num_mels = self.config.audio.num_mels,
-                hop_length = self.config.audio.hop_length,
-                win_length = self.config.audio.win_length,
+                y=self.model_outputs_cache["model_outputs"].float(),
+                n_fft=self.config.audio.fft_size,
+                sample_rate=self.config.audio.sample_rate,
+                num_mels=self.config.audio.num_mels,
+                hop_length=self.config.audio.hop_length,
+                win_length=self.config.audio.win_length,
                 fmin=self.config.audio.mel_fmin,
                 fmax=self.config.audio.mel_fmax,
                 center=False,
             )
@@ -1097,7 +1102,7 @@
 
         # compute discriminator scores and features
         scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
-           self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
+            self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
         )
 
         # compute losses
@@ -1105,18 +1110,18 @@
             loss_dict = criterion[optimizer_idx](
                 mel_slice_hat=mel_slice.float(),
                 mel_slice=mel_slice_hat.float(),
-                z_p= self.model_outputs_cache["z_p"].float(),
-                logs_q= self.model_outputs_cache["logs_q"].float(),
-                m_p= self.model_outputs_cache["m_p"].float(),
-                logs_p= self.model_outputs_cache["logs_p"].float(),
+                z_p=self.model_outputs_cache["z_p"].float(),
+                logs_q=self.model_outputs_cache["logs_q"].float(),
+                m_p=self.model_outputs_cache["m_p"].float(),
+                logs_p=self.model_outputs_cache["logs_p"].float(),
                 z_len=mel_lens,
-                scores_disc_fake= scores_disc_fake,
-                feats_disc_fake= feats_disc_fake,
-                feats_disc_real= feats_disc_real,
-                loss_duration= self.model_outputs_cache["loss_duration"],
+                scores_disc_fake=scores_disc_fake,
+                feats_disc_fake=feats_disc_fake,
+                feats_disc_real=feats_disc_real,
+                loss_duration=self.model_outputs_cache["loss_duration"],
                 use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
-                gt_spk_emb= self.model_outputs_cache["gt_spk_emb"],
-                syn_spk_emb= self.model_outputs_cache["syn_spk_emb"],
+                gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
+                syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
             )
 
         return self.model_outputs_cache, loss_dict
@@ -1248,7 +1253,9 @@
             test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
         return {"figures": test_figures, "audios": test_audios}
 
-    def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
+    def test_log(
+        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
+    ) -> None:
         logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
         logger.test_figures(steps, outputs["figures"])
 
@@ -1273,7 +1280,11 @@
             d_vectors = torch.FloatTensor(d_vectors)
 
         # get language ids from language names
-        if self.language_manager is not None and self.language_manager.language_id_mapping and self.args.use_language_embedding:
+        if (
+            self.language_manager is not None
+            and self.language_manager.language_id_mapping
+            and self.args.use_language_embedding
+        ):
             language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]
 
             if language_ids is not None:
@@ -1289,16 +1300,14 @@
         ac = self.config.audio
 
         # compute spectrograms
-        batch["spec"] = wav_to_spec(
-            batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False
-        )
+        batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
         batch["mel"] = spec_to_mel(
-            spec = batch["spec"],
-            n_fft = ac.fft_size,
-            num_mels = ac.num_mels,
-            sample_rate = ac.sample_rate,
-            fmin = ac.mel_fmin,
-            fmax = ac.mel_fmax,
+            spec=batch["spec"],
+            n_fft=ac.fft_size,
+            num_mels=ac.num_mels,
+            sample_rate=ac.sample_rate,
+            fmin=ac.mel_fmin,
+            fmax=ac.mel_fmax,
         )
         assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
 
@@ -1325,27 +1334,6 @@
         if is_eval and not config.run_eval:
             loader = None
         else:
-            # setup multi-speaker attributes
-            speaker_id_mapping = None
-            d_vector_mapping = None
-            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
-                if hasattr(config, "model_args"):
-                    speaker_id_mapping = (
-                        self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
-                    )
-                    d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
-                    config.use_d_vector_file = config.model_args.use_d_vector_file
-                else:
-                    speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
-                    d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
-
-            # setup multi-lingual attributes
-            language_id_mapping = None
-            if hasattr(self, "language_manager"):
-                language_id_mapping = (
-                    self.language_manager.language_id_mapping if self.args.use_language_embedding else None
-                )
-
             # init dataloader
             dataset = VitsDataset(
                 samples=samples,
@@ -1495,6 +1483,7 @@
         language_manager = LanguageManager.init_from_config(config)
         return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
 
+
 ##################################
 # VITS CHARACTERS
 ##################################
diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py
index 1366c4a6..c2e7f561 100644
--- a/TTS/tts/utils/helpers.py
+++ b/TTS/tts/utils/helpers.py
@@ -119,6 +119,7 @@ def rand_segments(
     ret = segment(x, segment_indices, segment_size)
     return ret, segment_indices
 
+
 def average_over_durations(values, durs):
     """Average values over durations.
 
diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py
index f6c04370..0ce65a90 100644
--- a/TTS/tts/utils/text/characters.py
+++ b/TTS/tts/utils/text/characters.py
@@ -1,4 +1,3 @@
-from abc import ABC
 from dataclasses import replace
 from typing import Dict
 
diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py
index 09087d5f..b2a058bb 100644
--- a/TTS/tts/utils/text/punctuation.py
+++ b/TTS/tts/utils/text/punctuation.py
@@ -57,7 +57,7 @@ class Punctuation:
         if not isinstance(value, six.string_types):
             raise ValueError("[!] Punctuations must be of type str.")
         self._puncs = "".join(list(dict.fromkeys(list(value))))  # remove duplicates without changing the oreder
-        self.puncs_regular_exp = re.compile(fr"(\s*[{re.escape(self._puncs)}]+\s*)+")
+        self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
 
     def strip(self, text):
         """Remove all the punctuations by replacing with `space`.
diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index 58fc8762..750258af 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -270,7 +270,7 @@ class Wavegrad(BaseVocoder):
     ) -> None:
         pass
 
-    def test(self, assets: Dict, test_loader:"DataLoader", outputs=None):  # pylint: disable=unused-argument
+    def test(self, assets: Dict, test_loader: "DataLoader", outputs=None):  # pylint: disable=unused-argument
         # setup noise schedule and inference
         ap = assets["audio_processor"]
         noise_schedule = self.config["test_noise_schedule"]
@@ -307,9 +307,7 @@
         y = y.unsqueeze(1)
         return {"input": m, "waveform": y}
 
-    def get_data_loader(
-        self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
-    ):
+    def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
         ap = assets["audio_processor"]
         dataset = WaveGradDataset(
             ap=ap,
diff --git a/recipes/multilingual/vits_tts/train_vits_tts.py b/recipes/multilingual/vits_tts/train_vits_tts.py
index ea4f377b..ac2c21a2 100644
--- a/recipes/multilingual/vits_tts/train_vits_tts.py
+++ b/recipes/multilingual/vits_tts/train_vits_tts.py
@@ -69,8 +69,8 @@ config = VitsConfig(
     print_eval=False,
     mixed_precision=False,
     sort_by_audio_len=True,
-    min_seq_len=32 * 256 * 4,
-    max_seq_len=160000,
+    min_audio_len=32 * 256 * 4,
+    max_audio_len=160000,
     output_path=output_path,
     datasets=dataset_config,
     characters={
diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py
index 85b5ed7a..2783e4bd 100644
--- a/tests/tts_tests/test_glow_tts.py
+++ b/tests/tts_tests/test_glow_tts.py
@@ -4,6 +4,7 @@ import unittest
 
 import torch
 from torch import optim
+from trainer.logging.tensorboard_logger import TensorboardLogger
 
 from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path
 from TTS.tts.configs.glow_tts_config import GlowTTSConfig
@@ -11,7 +12,6 @@ from TTS.tts.layers.losses import GlowTTSLoss
 from TTS.tts.models.glow_tts import GlowTTS
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
-from trainer.logging.tensorboard_logger import TensorboardLogger
 
 # pylint: disable=unused-variable
 
diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index 4018c6bd..204ff2f7 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -3,15 +3,14 @@ import os
 import unittest
 
 import torch
-from TTS.tts.datasets.formatters import ljspeech
+from trainer.logging.tensorboard_logger import TensorboardLogger
 
 from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
 from TTS.config import load_config
 from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
 from TTS.tts.configs.vits_config import VitsConfig
-from TTS.tts.models.vits import Vits, VitsArgs, load_audio, amp_to_db, db_to_amp, wav_to_spec, wav_to_mel, spec_to_mel, VitsDataset
+from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
 from TTS.tts.utils.speakers import SpeakerManager
-from trainer.logging.tensorboard_logger import TensorboardLogger
 
 LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
 SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
@@ -31,7 +30,17 @@ class TestVits(unittest.TestCase):
         self.assertEqual(sr, 22050)
 
         spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False)
-        mel = wav_to_mel(wav, n_fft=1024, num_mels=80, sample_rate=sr, hop_length=512, win_length=1024, fmin=0, fmax=8000, center=False)
+        mel = wav_to_mel(
+            wav,
+            n_fft=1024,
+            num_mels=80,
+            sample_rate=sr,
+            hop_length=512,
+            win_length=1024,
+            fmin=0,
+            fmax=8000,
+            center=False,
+        )
         mel2 = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000)
 
         self.assertEqual((mel - mel2).abs().max(), 0)
@@ -45,7 +54,7 @@
 
     def test_dataset(self):
         """TODO:"""
-        ... 
+        ...
 
     def test_init_multispeaker(self):
         num_speakers = 10
@@ -164,7 +173,7 @@
         num_speakers = 0
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
-        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
+        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config)
         model = Vits(config).to(device)
         output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
         self._check_forward_outputs(config, output_dict)
@@ -175,7 +184,7 @@
 
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
-        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config)
+        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config)
         speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
 
         model = Vits(config).to(device)
@@ -196,7 +205,7 @@
         config = VitsConfig(model_args=args)
         model = Vits.init_from_config(config, verbose=False).to(device)
         model.train()
-        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         d_vectors = torch.randn(batch_size, 256).to(device)
         output_dict = model.forward(
             input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors}
@@ -211,7 +220,7 @@
         args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
 
-        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
         lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
 
@@ -246,7 +255,7 @@
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
         config.audio.sample_rate = 16000
 
-        input_dummy, input_lengths, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
         lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)