Make stlye

pull/1324/head
Eren Gölge 2022-02-20 12:37:27 +01:00
parent acc83cd3e6
commit 1e414b3a09
8 changed files with 8 additions and 19 deletions

View File

@ -10,8 +10,6 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from trainer.torch import NoamLR from trainer.torch import NoamLR
from trainer.torch import NoamLR
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model

View File

@ -1,4 +1,3 @@
from dataclasses import dataclass, field
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field

View File

@ -1,4 +1,3 @@
from dataclasses import dataclass, field
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field

View File

@ -5,11 +5,11 @@ import torch
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn
# pylint: skip-file
class BaseTrainerModel(ABC, nn.Module): class BaseTrainerModel(ABC, nn.Module):
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this. """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""
"""
@staticmethod @staticmethod
@abstractmethod @abstractmethod

View File

@ -1,20 +1,15 @@
from asyncio.log import logger
from dataclasses import dataclass, field
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from coqpit import Coqpit from coqpit import Coqpit
from trainer import TrainerArgs from trainer import TrainerArgs, get_last_checkpoint
from trainer.logging import logger_factory from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config from TTS.config import load_config, register_config
from trainer import TrainerArgs, get_last_checkpoint
from TTS.tts.utils.text.characters import parse_symbols from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files from TTS.utils.io import copy_model_files
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
@dataclass @dataclass

View File

@ -1,9 +1,8 @@
import collections
import math import math
import os import os
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from itertools import chain from itertools import chain
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -545,7 +544,8 @@ class Vits(BaseTTS):
ap: "AudioProcessor" = None, ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None, tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None, speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,): language_manager: LanguageManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager, language_manager) super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
@ -1483,6 +1483,7 @@ class Vits(BaseTTS):
language_manager = LanguageManager.init_from_config(config) language_manager = LanguageManager.init_from_config(config)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
################################## ##################################
# VITS CHARACTERS # VITS CHARACTERS
################################## ##################################

View File

@ -1,4 +1,3 @@
from abc import ABC
from dataclasses import replace from dataclasses import replace
from typing import Dict from typing import Dict

View File

@ -307,9 +307,7 @@ class Wavegrad(BaseVocoder):
y = y.unsqueeze(1) y = y.unsqueeze(1)
return {"input": m, "waveform": y} return {"input": m, "waveform": y}
def get_data_loader( def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int):
self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
):
ap = assets["audio_processor"] ap = assets["audio_processor"]
dataset = WaveGradDataset( dataset = WaveGradDataset(
ap=ap, ap=ap,