mirror of https://github.com/coqui-ai/TTS.git
Make stlye
parent
acc83cd3e6
commit
1e414b3a09
|
@ -10,8 +10,6 @@ import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from trainer.torch import NoamLR
|
from trainer.torch import NoamLR
|
||||||
|
|
||||||
from trainer.torch import NoamLR
|
|
||||||
|
|
||||||
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
||||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||||
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
|
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
from dataclasses import dataclass, field
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
from dataclasses import dataclass, field
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
|
@ -5,11 +5,11 @@ import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
# pylint: skip-file
|
||||||
|
|
||||||
|
|
||||||
class BaseTrainerModel(ABC, nn.Module):
|
class BaseTrainerModel(ABC, nn.Module):
|
||||||
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.
|
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
@ -1,20 +1,15 @@
|
||||||
from asyncio.log import logger
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from trainer import TrainerArgs
|
from trainer import TrainerArgs, get_last_checkpoint
|
||||||
from trainer.logging import logger_factory
|
from trainer.logging import logger_factory
|
||||||
from trainer.logging.console_logger import ConsoleLogger
|
from trainer.logging.console_logger import ConsoleLogger
|
||||||
|
|
||||||
from TTS.config import load_config, register_config
|
from TTS.config import load_config, register_config
|
||||||
from trainer import TrainerArgs, get_last_checkpoint
|
|
||||||
from TTS.tts.utils.text.characters import parse_symbols
|
from TTS.tts.utils.text.characters import parse_symbols
|
||||||
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
||||||
from TTS.utils.io import copy_model_files
|
from TTS.utils.io import copy_model_files
|
||||||
from trainer.logging import logger_factory
|
|
||||||
from trainer.logging.console_logger import ConsoleLogger
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
import collections
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -545,7 +544,8 @@ class Vits(BaseTTS):
|
||||||
ap: "AudioProcessor" = None,
|
ap: "AudioProcessor" = None,
|
||||||
tokenizer: "TTSTokenizer" = None,
|
tokenizer: "TTSTokenizer" = None,
|
||||||
speaker_manager: SpeakerManager = None,
|
speaker_manager: SpeakerManager = None,
|
||||||
language_manager: LanguageManager = None,):
|
language_manager: LanguageManager = None,
|
||||||
|
):
|
||||||
|
|
||||||
super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
|
super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
|
||||||
|
|
||||||
|
@ -1483,6 +1483,7 @@ class Vits(BaseTTS):
|
||||||
language_manager = LanguageManager.init_from_config(config)
|
language_manager = LanguageManager.init_from_config(config)
|
||||||
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
|
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
|
||||||
|
|
||||||
|
|
||||||
##################################
|
##################################
|
||||||
# VITS CHARACTERS
|
# VITS CHARACTERS
|
||||||
##################################
|
##################################
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
from abc import ABC
|
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
|
|
@ -307,9 +307,7 @@ class Wavegrad(BaseVocoder):
|
||||||
y = y.unsqueeze(1)
|
y = y.unsqueeze(1)
|
||||||
return {"input": m, "waveform": y}
|
return {"input": m, "waveform": y}
|
||||||
|
|
||||||
def get_data_loader(
|
def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
|
||||||
self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
|
|
||||||
):
|
|
||||||
ap = assets["audio_processor"]
|
ap = assets["audio_processor"]
|
||||||
dataset = WaveGradDataset(
|
dataset = WaveGradDataset(
|
||||||
ap=ap,
|
ap=ap,
|
||||||
|
|
Loading…
Reference in New Issue