linter fixes

pull/602/head
Eren Gölge 2021-05-28 14:04:51 +02:00
parent 79f7c5da1e
commit 26e7c0960c
10 changed files with 34 additions and 105 deletions

View File

@@ -8,20 +8,20 @@ from TTS.utils.generic_utils import remove_experiment_folder
 def main():
-    # try:
-    args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
-    trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH)
-    trainer.fit()
-    # except KeyboardInterrupt:
-    #     remove_experiment_folder(OUT_PATH)
-    #     try:
-    #         sys.exit(0)
-    #     except SystemExit:
-    #         os._exit(0)  # pylint: disable=protected-access
-    # except Exception:  # pylint: disable=broad-except
-    #     remove_experiment_folder(OUT_PATH)
-    #     traceback.print_exc()
-    #     sys.exit(1)
+    try:
+        args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
+        trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=output_path)
+        trainer.fit()
+    except KeyboardInterrupt:
+        remove_experiment_folder(output_path)
+        try:
+            sys.exit(0)
+        except SystemExit:
+            os._exit(0)  # pylint: disable=protected-access
+    except Exception:  # pylint: disable=broad-except
+        remove_experiment_folder(output_path)
+        traceback.print_exc()
+        sys.exit(1)


 if __name__ == "__main__":
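A note on the restored handler: sys.exit(0) raises SystemExit, and if anything intercepts that exception during teardown, os._exit(0) (which skips all further cleanup) still forces the process down; that is why the protected-access disable is needed. A minimal sketch of the same interrupt-cleanup pattern, with run_training and the print standing in for trainer.fit() and remove_experiment_folder():

    import os
    import sys
    import time

    def run_training():
        time.sleep(60)  # hypothetical stand-in for trainer.fit()

    try:
        run_training()
    except KeyboardInterrupt:
        print("removing partial experiment folder")  # stand-in for remove_experiment_folder()
        try:
            sys.exit(0)  # raises SystemExit
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access  # hard exit, no further teardown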

View File

@@ -184,7 +184,7 @@ class TrainerTTS:
     @staticmethod
     def get_speaker_manager(
-        config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = []
+        config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None
     ) -> SpeakerManager:
         speaker_manager = SpeakerManager()
         if restore_path:
@@ -208,7 +208,9 @@
         return speaker_manager

     @staticmethod
-    def get_scheduler(config: Coqpit, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+    def get_scheduler(
+        config: Coqpit, optimizer: torch.optim.Optimizer
+    ) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
         lr_scheduler = config.lr_scheduler
         lr_scheduler_params = config.lr_scheduler_params
         if lr_scheduler is None:
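The data_train: List = [] change above addresses pylint's dangerous-default-value: a default list is evaluated once at function definition, so every call shares (and can mutate) the same object. A minimal sketch of the pitfall and the None-sentinel idiom the commit switches to (function names are illustrative, not from the codebase):

    from typing import List, Optional

    def bad_append(item, items: List = []):  # pylint: disable=dangerous-default-value
        items.append(item)  # mutates the single shared default list
        return items

    def good_append(item, items: Optional[List] = None):
        if items is None:
            items = []  # a fresh list on every call
        items.append(item)
        return items

    print(bad_append(1))   # [1]
    print(bad_append(2))   # [1, 2]  <- state leaked from the first call
    print(good_append(1))  # [1]
    print(good_append(2))  # [2]

Note that with data_train: List = None the annotation is effectively Optional[List]; pylint accepts the sentinel either way.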

View File

@@ -275,7 +275,7 @@ class AlignTTS(nn.Module):
             g: [B, C]
         """
         g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
-        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
+        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)  # pylint: disable=not-callable
         # pad input to prevent dropping the last word
         # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
@@ -314,7 +314,7 @@
             loss_dict["align_error"] = align_error
         return outputs, loss_dict

-    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
         model_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         mel_input = batch["mel_input"]
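The # pylint: disable=not-callable comments work around a known pylint false positive (E1102): torch.tensor is defined in PyTorch's C extension, so static analysis cannot prove it is callable. The call itself is correct; a quick sketch of what the flagged line computes, assuming a fake [B, T] input batch:

    import torch

    x = torch.zeros(1, 37, dtype=torch.long)  # fake [B, T] token batch
    # x.shape[1:2] is torch.Size([37]); torch.tensor() turns it into a
    # one-element length tensor, then .to() moves it to the input's device
    x_lengths = torch.tensor(x.shape[1:2]).to(x.device)  # pylint: disable=not-callable
    print(x_lengths)  # tensor([37])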

View File

@@ -143,7 +143,9 @@ class GlowTTS(nn.Module):
         o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
         return y_mean, y_log_scale, o_attn_dur

-    def forward(self, x, x_lengths, y, y_lengths=None, cond_input={"x_vectors": None}):
+    def forward(
+        self, x, x_lengths, y, y_lengths=None, cond_input={"x_vectors": None}
+    ):  # pylint: disable=dangerous-default-value
         """
         Shapes:
             x: [B, T]
@@ -344,7 +346,7 @@
             loss_dict["align_error"] = align_error
         return outputs, loss_dict

-    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
         model_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         mel_input = batch["mel_input"]
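Here the same dangerous-default-value rule is silenced rather than fixed: cond_input={"x_vectors": None} keeps the public forward() signature intact, and the dict appears to be only read, never mutated, so the shared default is harmless in practice. If the signature were free to change, the usual rewrite would use a None sentinel (a simplified sketch, not the actual GlowTTS code):

    def forward(x, x_lengths, y, y_lengths=None, cond_input=None):
        if cond_input is None:
            cond_input = {"x_vectors": None}  # fresh dict per call, nothing shared
        g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
        return g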

View File

@@ -183,7 +183,7 @@ class SpeedySpeech(nn.Module):
             g: [B, C]
         """
         g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
-        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
+        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)  # pylint: disable=not-callable
         # input sequence should be greater than the max convolution size
         inference_padding = 5
         if x.shape[1] < 13:
@@ -226,7 +226,7 @@
             loss_dict["align_error"] = align_error
         return outputs, loss_dict

-    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
         model_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         mel_input = batch["mel_input"]
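train_log never touches self, which trips pylint's no-self-use (R0201) in all three models in this commit. Declaring it a @staticmethod would also satisfy the rule, but that changes the declared interface; the disable comment is the lighter-touch option that keeps a uniform model.train_log(ap, batch, outputs) instance method. Both forms, sketched on a toy class (the body is illustrative):

    class ModelA:
        def train_log(self, ap, batch, outputs):  # pylint: disable=no-self-use
            return {"mel": outputs["model_outputs"]}

    class ModelB:
        @staticmethod
        def train_log(ap, batch, outputs):  # also lint-clean, still callable as model.train_log(...)
            return {"mel": outputs["model_outputs"]}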

View File

@@ -79,7 +79,7 @@ class Tacotron(TacotronAbstract):
         use_gst=False,
         gst=None,
         memory_size=5,
-        gradual_training=[],
+        gradual_training=None,
     ):
         super().__init__(
             num_chars,
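gradual_training=[] is the same mutable-default pattern once more; the identical change lands in Tacotron2 and TacotronAbstract below. Since None and [] are both falsy, a guard like if self.gradual_training: behaves the same after the change. A hedged sketch of how a constructor can normalize the sentinel (not the actual Tacotron code):

    class TacotronLike:
        def __init__(self, gradual_training=None):
            # `or []` gives a fresh list per instance; None and [] are both
            # falsy, so existing `if self.gradual_training:` checks are unaffected
            self.gradual_training = gradual_training or []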

View File

@@ -1,5 +1,4 @@
 # coding: utf-8
-import numpy as np
 import torch
 from torch import nn
@@ -77,7 +76,7 @@ class Tacotron2(TacotronAbstract):
         speaker_embedding_dim=None,
         use_gst=False,
         gst=None,
-        gradual_training=[],
+        gradual_training=None,
     ):
         super().__init__(
             num_chars,

View File

@@ -1,5 +1,4 @@
 import copy
-import logging
 from abc import ABC, abstractmethod

 import torch
@@ -37,7 +36,7 @@ class TacotronAbstract(ABC, nn.Module):
         speaker_embedding_dim=None,
         use_gst=False,
         gst=None,
-        gradual_training=[],
+        gradual_training=None,
     ):
         """Abstract Tacotron class"""
         super().__init__()
@@ -239,4 +238,4 @@ class TacotronAbstract(ABC, nn.Module):
         trainer.model.decoder_backward.set_r(r)
         trainer.train_loader = trainer.setup_train_dataloader(self.ap, self.model.decoder.r, verbose=True)
         trainer.eval_loader = trainer.setup_eval_dataloder(self.ap, self.model.decoder.r)
-        logging.info(f"\n > Number of output frames: {self.decoder.r}")
+        print(f"\n > Number of output frames: {self.decoder.r}")
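Dropping import logging in favor of print removes what appears to be this file's only logging call, and logging.info(f"...") is also the pattern pylint flags as logging-fstring-interpolation (W1203), since the f-string is built eagerly even when the log level would discard the message; the print form matches the " > ..." console messages used elsewhere in the codebase. Had logging stayed, the lint-clean form would be lazy %-formatting (a sketch):

    import logging

    logger = logging.getLogger(__name__)
    r = 2  # stand-in for self.decoder.r

    # interpolation is deferred until the record is actually emitted
    logger.info("\n > Number of output frames: %d", r)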

View File

@@ -1,5 +1,4 @@
-import json
 import os
 import random
 from typing import Any, List, Union
@@ -11,79 +10,6 @@ from TTS.speaker_encoder.utils.generic_utils import setup_model
 from TTS.utils.audio import AudioProcessor


-def make_speakers_json_path(out_path):
-    """Returns conventional speakers.json location."""
-    return os.path.join(out_path, "speakers.json")
-
-
-def load_speaker_mapping(out_path):
-    """Loads speaker mapping if already present."""
-    if os.path.splitext(out_path)[1] == ".json":
-        json_file = out_path
-    else:
-        json_file = make_speakers_json_path(out_path)
-    with open(json_file) as f:
-        return json.load(f)
-
-
-def save_speaker_mapping(out_path, speaker_mapping):
-    """Saves speaker mapping if not yet present."""
-    if out_path is not None:
-        speakers_json_path = make_speakers_json_path(out_path)
-        with open(speakers_json_path, "w") as f:
-            json.dump(speaker_mapping, f, indent=4)
-
-
-def parse_speakers(c, args, meta_data_train, OUT_PATH):
-    """Returns number of speakers, speaker embedding shape and speaker mapping"""
-    if c.use_speaker_embedding:
-        speakers = get_speakers(meta_data_train)
-        if args.restore_path:
-            if c.use_external_speaker_embedding_file:  # if restore checkpoint and use External Embedding file
-                prev_out_path = os.path.dirname(args.restore_path)
-                speaker_mapping = load_speaker_mapping(prev_out_path)
-                if not speaker_mapping:
-                    print(
-                        "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
-                    )
-                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
-                    if not speaker_mapping:
-                        raise RuntimeError(
-                            "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file"
-                        )
-                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
-            elif (
-                not c.use_external_speaker_embedding_file
-            ):  # if restore checkpoint and don't use External Embedding file
-                prev_out_path = os.path.dirname(args.restore_path)
-                speaker_mapping = load_speaker_mapping(prev_out_path)
-                speaker_embedding_dim = None
-                assert all(speaker in speaker_mapping for speaker in speakers), (
-                    "As of now you, you cannot " "introduce new speakers to " "a previously trained model."
-                )
-        elif (
-            c.use_external_speaker_embedding_file and c.external_speaker_embedding_file
-        ):  # if start new train using External Embedding file
-            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
-            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
-        elif (
-            c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file
-        ):  # if start new train using External Embedding file and don't pass external embedding file
-            raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
-        else:  # if start new train and don't use External Embedding file
-            speaker_mapping = {name: i for i, name in enumerate(speakers)}
-            speaker_embedding_dim = None
-        save_speaker_mapping(OUT_PATH, speaker_mapping)
-        num_speakers = len(speaker_mapping)
-        print(" > Training with {} speakers: {}".format(len(speakers), ", ".join(speakers)))
-    else:
-        num_speakers = 0
-        speaker_embedding_dim = None
-        speaker_mapping = None
-    return num_speakers, speaker_embedding_dim, speaker_mapping
-
-
 class SpeakerManager:
     """It manages the multi-speaker setup for 🐸TTS models. It loads the speaker files and parses the information
     in a way that you can query. There are 3 different scenarios considered.
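One detail in the deleted parse_speakers block is worth flagging: the bare raise "use_external_speaker_embedding_file is True, ..." was already broken before its removal, because Python 3 only raises exception classes or instances; raising a string dies with TypeError: exceptions must derive from BaseException instead of showing the intended message. Had the function survived, the working form would have been something like:

    if c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file:
        raise RuntimeError(
            "use_external_speaker_embedding_file is True, so an external speaker "
            "embedding file must be provided; see the extraction notebooks in notebooks/."
        )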

View File

@@ -230,15 +230,16 @@ def synthesis(
         outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, x_vector=x_vector)
         model_outputs = outputs["model_outputs"]
         model_outputs = model_outputs[0].data.cpu().numpy()
+        alignments = outputs["alignments"]
     elif backend == "tf":
         decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
             model, text_inputs, CONFIG, speaker_id, style_mel
         )
-        model_outputs, decoder_output, alignment, stop_tokens = parse_outputs_tf(
+        model_outputs, decoder_output, alignments, stop_tokens = parse_outputs_tf(
             postnet_output, decoder_output, alignments, stop_tokens
         )
     elif backend == "tflite":
-        decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
+        decoder_output, postnet_output, alignments, stop_tokens = run_model_tflite(
             model, text_inputs, CONFIG, speaker_id, style_mel
         )
         model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
@@ -252,7 +253,7 @@ def synthesis(
         wav = trim_silence(wav, ap)
     return_dict = {
         "wav": wav,
-        "alignments": outputs["alignments"],
+        "alignments": alignments,
         "model_outputs": model_outputs,
         "text_inputs": text_inputs,
     }
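The second hunk is the substantive fix in this file: return_dict["alignments"] previously read outputs["alignments"], but outputs is only bound in the torch branch, so the tf and tflite paths would presumably have failed with a NameError (the tflite branch also bound the singular alignment, which was never used). Binding the same alignments local in every branch makes the code after the if/elif chain backend-agnostic. The shape of the fix, reduced to a runnable sketch with stand-in values:

    def gather_alignments(backend):
        # every branch binds the same local, so the code below the
        # chain no longer cares which backend produced the result
        if backend == "torch":
            outputs = {"alignments": "torch-align"}  # stand-in for run_model_torch(...)
            alignments = outputs["alignments"]
        elif backend == "tf":
            alignments = "tf-align"
        elif backend == "tflite":
            alignments = "tflite-align"
        return {"alignments": alignments}

    print(gather_alignments("tf"))  # the old code would have raised NameError here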