mirror of https://github.com/coqui-ai/TTS.git
Make linter
parent
0b1986384f
commit
37959ad0c7
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
from typing import List, Union
|
||||
|
||||
from coqpit import Coqpit
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ class TrainingArgs(Coqpit):
|
|||
|
||||
|
||||
class Trainer:
|
||||
def __init__(
|
||||
def __init__( # pylint: disable=dangerous-default-value
|
||||
self,
|
||||
args: Union[Coqpit, Namespace],
|
||||
config: Coqpit,
|
||||
|
@ -335,7 +335,9 @@ class Trainer:
|
|||
args.parse_args(training_args)
|
||||
return args, coqpit_overrides
|
||||
|
||||
def init_training(self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None):
|
||||
def init_training(
|
||||
self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None
|
||||
): # pylint: disable=no-self-use
|
||||
"""Initialize training and update model configs from command line arguments.
|
||||
|
||||
Args:
|
||||
|
@ -387,14 +389,13 @@ class Trainer:
|
|||
|
||||
@staticmethod
|
||||
def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> nn.Module:
|
||||
if isinstance(get_data_samples, Callable):
|
||||
if callable(get_data_samples):
|
||||
if len(signature(get_data_samples).sig.parameters) == 1:
|
||||
train_samples, eval_samples = get_data_samples(config)
|
||||
else:
|
||||
train_samples, eval_samples = get_data_samples()
|
||||
return train_samples, eval_samples
|
||||
else:
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
def restore_model(
|
||||
self,
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
@ -13,7 +12,6 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
|||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
|
@ -360,7 +358,7 @@ class AlignTTS(BaseTTS):
|
|||
|
||||
return outputs, loss_dict
|
||||
|
||||
def _create_logs(self, batch, outputs, ap):
|
||||
def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use
|
||||
model_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
|
|
@ -1,17 +1,15 @@
|
|||
import copy
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from coqpit import MISSING, Coqpit
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.losses import TacotronLoss
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.tts.utils.text import make_symbols
|
||||
from TTS.utils.generic_utils import format_aux_input
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.training import gradual_training_scheduler
|
||||
|
|
|
@ -14,7 +14,6 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
|||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -14,7 +14,6 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
|
|||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
@ -11,7 +9,6 @@ from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
|
|||
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class Tacotron(BaseTacotron):
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
@ -11,7 +9,6 @@ from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
|
|||
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class Tacotron2(BaseTacotron):
|
||||
|
|
|
@ -17,7 +17,6 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, se
|
|||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
|
||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
@ -576,7 +575,7 @@ class Vits(BaseTTS):
|
|||
)
|
||||
return outputs, loss_dict
|
||||
|
||||
def _log(self, ap, batch, outputs, name_prefix="train"):
|
||||
def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use
|
||||
y_hat = outputs[0]["model_outputs"]
|
||||
y = outputs[0]["waveform_seg"]
|
||||
figures = plot_results(y_hat, y, ap, name_prefix)
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
from coqpit import MISSING
|
||||
|
||||
from TTS.config import BaseAudioConfig, BaseTrainingConfig
|
||||
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@ from torch.nn.utils import weight_norm
|
|||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
|
||||
from TTS.vocoder.datasets import WaveGradDataset
|
||||
|
|
|
@ -6,7 +6,6 @@ from TTS.tts.configs import SpeedySpeechConfig
|
|||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.models.forward_tts import ForwardTTS
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
output_path = os.path.dirname(os.path.abspath(__file__))
|
||||
dataset_config = BaseDatasetConfig(
|
||||
|
|
Loading…
Reference in New Issue