Make linter

pull/847/head
Eren Gölge 2021-09-30 23:02:16 +00:00
parent 0b1986384f
commit 37959ad0c7
12 changed files with 9 additions and 26 deletions

View File

@ -1,5 +1,4 @@
import os
from typing import List, Union
from coqpit import Coqpit

View File

@ -90,7 +90,7 @@ class TrainingArgs(Coqpit):
class Trainer:
def __init__(
def __init__( # pylint: disable=dangerous-default-value
self,
args: Union[Coqpit, Namespace],
config: Coqpit,
@ -335,7 +335,9 @@ class Trainer:
args.parse_args(training_args)
return args, coqpit_overrides
def init_training(self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None):
def init_training(
self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None
): # pylint: disable=no-self-use
"""Initialize training and update model configs from command line arguments.
Args:
@ -387,14 +389,13 @@ class Trainer:
@staticmethod
def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> nn.Module:
if isinstance(get_data_samples, Callable):
if callable(get_data_samples):
if len(signature(get_data_samples).sig.parameters) == 1:
train_samples, eval_samples = get_data_samples(config)
else:
train_samples, eval_samples = get_data_samples()
return train_samples, eval_samples
else:
return None, None
return None, None
def restore_model(
self,

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass, field
from typing import Dict, Tuple
import torch
from coqpit import Coqpit
@ -13,7 +12,6 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
@ -360,7 +358,7 @@ class AlignTTS(BaseTTS):
return outputs, loss_dict
def _create_logs(self, batch, outputs, ap):
def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use
model_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
mel_input = batch["mel_input"]

View File

@ -1,17 +1,15 @@
import copy
from abc import abstractmethod
from dataclasses import dataclass
from typing import Dict, List
import torch
from coqpit import MISSING, Coqpit
from coqpit import Coqpit
from torch import nn
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.text import make_symbols
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler

View File

@ -14,7 +14,6 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
from TTS.utils.audio import AudioProcessor
@dataclass

View File

@ -14,7 +14,6 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec

View File

@ -1,7 +1,5 @@
# coding: utf-8
from typing import Dict, Tuple
import torch
from coqpit import Coqpit
from torch import nn
@ -11,7 +9,6 @@ from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
from TTS.tts.models.base_tacotron import BaseTacotron
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
class Tacotron(BaseTacotron):

View File

@ -1,7 +1,5 @@
# coding: utf-8
from typing import Dict, Tuple
import torch
from coqpit import Coqpit
from torch import nn
@ -11,7 +9,6 @@ from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
from TTS.tts.models.base_tacotron import BaseTacotron
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
class Tacotron2(BaseTacotron):

View File

@ -17,7 +17,6 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.audio import AudioProcessor
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
@ -576,7 +575,7 @@ class Vits(BaseTTS):
)
return outputs, loss_dict
def _log(self, ap, batch, outputs, name_prefix="train"):
def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use
y_hat = outputs[0]["model_outputs"]
y = outputs[0]["waveform_seg"]
figures = plot_results(y_hat, y, ap, name_prefix)

View File

@ -1,7 +1,5 @@
from dataclasses import dataclass, field
from coqpit import MISSING
from TTS.config import BaseAudioConfig, BaseTrainingConfig

View File

@ -9,7 +9,6 @@ from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets import WaveGradDataset

View File

@ -6,7 +6,6 @@ from TTS.tts.configs import SpeedySpeechConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS
from TTS.utils.audio import AudioProcessor
from TTS.utils.manage import ModelManager
output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(