Make stlye

pull/1324/head
Eren Gölge 2022-02-20 12:37:27 +01:00
parent acc83cd3e6
commit 1e414b3a09
8 changed files with 8 additions and 19 deletions

View File

@ -10,8 +10,6 @@ import torch
from torch.utils.data import DataLoader
from trainer.torch import NoamLR
from trainer.torch import NoamLR
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model

View File

@ -1,4 +1,3 @@
from dataclasses import dataclass, field
import os
from dataclasses import dataclass, field

View File

@ -1,4 +1,3 @@
from dataclasses import dataclass, field
import os
from dataclasses import dataclass, field

View File

@ -5,11 +5,11 @@ import torch
from coqpit import Coqpit
from torch import nn
# pylint: skip-file
class BaseTrainerModel(ABC, nn.Module):
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
"""
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""
@staticmethod
@abstractmethod

View File

@ -1,20 +1,15 @@
from asyncio.log import logger
from dataclasses import dataclass, field
import os
from dataclasses import dataclass, field
from coqpit import Coqpit
from trainer import TrainerArgs
from trainer import TrainerArgs, get_last_checkpoint
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config
from trainer import TrainerArgs, get_last_checkpoint
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
@dataclass

View File

@ -1,9 +1,8 @@
import collections
import math
import os
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple, Union
import torch
import torch.distributed as dist
@ -545,7 +544,8 @@ class Vits(BaseTTS):
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,):
language_manager: LanguageManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
@ -1483,6 +1483,7 @@ class Vits(BaseTTS):
language_manager = LanguageManager.init_from_config(config)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
##################################
# VITS CHARACTERS
##################################

View File

@ -1,4 +1,3 @@
from abc import ABC
from dataclasses import replace
from typing import Dict

View File

@ -307,9 +307,7 @@ class Wavegrad(BaseVocoder):
y = y.unsqueeze(1)
return {"input": m, "waveform": y}
def get_data_loader(
self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
):
def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
ap = assets["audio_processor"]
dataset = WaveGradDataset(
ap=ap,