mirror of https://github.com/coqui-ai/TTS.git
Make stlye
parent
acc83cd3e6
commit
1e414b3a09
|
@ -10,8 +10,6 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
from trainer.torch import NoamLR
|
||||
|
||||
from trainer.torch import NoamLR
|
||||
|
||||
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
from dataclasses import dataclass, field
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
from dataclasses import dataclass, field
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
|
|
@ -5,11 +5,11 @@ import torch
|
|||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
|
||||
class BaseTrainerModel(ABC, nn.Module):
|
||||
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
|
||||
"""
|
||||
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
|
|
|
@ -1,20 +1,15 @@
|
|||
from asyncio.log import logger
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from coqpit import Coqpit
|
||||
from trainer import TrainerArgs
|
||||
from trainer import TrainerArgs, get_last_checkpoint
|
||||
from trainer.logging import logger_factory
|
||||
from trainer.logging.console_logger import ConsoleLogger
|
||||
|
||||
from TTS.config import load_config, register_config
|
||||
from trainer import TrainerArgs, get_last_checkpoint
|
||||
from TTS.tts.utils.text.characters import parse_symbols
|
||||
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
||||
from TTS.utils.io import copy_model_files
|
||||
from trainer.logging import logger_factory
|
||||
from trainer.logging.console_logger import ConsoleLogger
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import collections
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass, field, replace
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -545,7 +544,8 @@ class Vits(BaseTTS):
|
|||
ap: "AudioProcessor" = None,
|
||||
tokenizer: "TTSTokenizer" = None,
|
||||
speaker_manager: SpeakerManager = None,
|
||||
language_manager: LanguageManager = None,):
|
||||
language_manager: LanguageManager = None,
|
||||
):
|
||||
|
||||
super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
|
||||
|
||||
|
@ -1483,6 +1483,7 @@ class Vits(BaseTTS):
|
|||
language_manager = LanguageManager.init_from_config(config)
|
||||
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
|
||||
|
||||
|
||||
##################################
|
||||
# VITS CHARACTERS
|
||||
##################################
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
from abc import ABC
|
||||
from dataclasses import replace
|
||||
from typing import Dict
|
||||
|
||||
|
|
|
@ -307,9 +307,7 @@ class Wavegrad(BaseVocoder):
|
|||
y = y.unsqueeze(1)
|
||||
return {"input": m, "waveform": y}
|
||||
|
||||
def get_data_loader(
|
||||
self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
|
||||
):
|
||||
def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
|
||||
ap = assets["audio_processor"]
|
||||
dataset = WaveGradDataset(
|
||||
ap=ap,
|
||||
|
|
Loading…
Reference in New Issue