mirror of https://github.com/coqui-ai/TTS.git
linter fixes
parent 79f7c5da1e
commit 26e7c0960c
@@ -8,20 +8,20 @@ from TTS.utils.generic_utils import remove_experiment_folder


 def main():
-    # try:
-    args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
-    trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH)
-    trainer.fit()
-    # except KeyboardInterrupt:
-    #     remove_experiment_folder(OUT_PATH)
-    #     try:
-    #         sys.exit(0)
-    #     except SystemExit:
-    #         os._exit(0)  # pylint: disable=protected-access
-    # except Exception:  # pylint: disable=broad-except
-    #     remove_experiment_folder(OUT_PATH)
-    #     traceback.print_exc()
-    #     sys.exit(1)
+    try:
+        args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
+        trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=output_path)
+        trainer.fit()
+    except KeyboardInterrupt:
+        remove_experiment_folder(output_path)
+        try:
+            sys.exit(0)
+        except SystemExit:
+            os._exit(0)  # pylint: disable=protected-access
+    except Exception:  # pylint: disable=broad-except
+        remove_experiment_folder(output_path)
+        traceback.print_exc()
+        sys.exit(1)


 if __name__ == "__main__":
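The commented-out error handling in main() is re-enabled here, with the linter-friendly lowercase output_path instead of OUT_PATH. The pattern is common in training entry points: on Ctrl-C the half-written experiment folder is removed before exiting, and os._exit(0) acts as a hard fallback in case the SystemExit raised by sys.exit(0) is swallowed further up the stack. Below is a minimal standalone sketch of the same pattern; run_training and cleanup are illustrative stand-ins, not functions from this repository.

    import os
    import sys
    import traceback


    def cleanup(path):
        # stand-in for TTS.utils.generic_utils.remove_experiment_folder
        print(f"removing unfinished experiment at {path}")


    def run_training(path):
        # stand-in for TrainerTTS(...).fit()
        print(f"training, writing checkpoints under {path}")


    def main():
        output_path = "/tmp/run-0"
        try:
            run_training(output_path)
        except KeyboardInterrupt:
            cleanup(output_path)
            try:
                sys.exit(0)
            except SystemExit:
                os._exit(0)  # hard exit even if SystemExit is intercepted
        except Exception:  # broad on purpose: always clean up and report
            cleanup(output_path)
            traceback.print_exc()
            sys.exit(1)


    if __name__ == "__main__":
        main()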
@@ -184,7 +184,7 @@ class TrainerTTS:

     @staticmethod
     def get_speaker_manager(
-        config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = []
+        config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None
     ) -> SpeakerManager:
         speaker_manager = SpeakerManager()
         if restore_path:
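data_train: List = [] is what pylint's dangerous-default-value check targets: a default is evaluated once, when the function is defined, so every call that relies on the default shares one list object. Switching the default to None (and, presumably, creating the list inside the body when needed) avoids the shared state. A small illustration of the failure mode and the None-sentinel idiom:

    def bad_append(item, bucket=[]):
        # the same list object is reused across calls
        bucket.append(item)
        return bucket


    def good_append(item, bucket=None):
        # a fresh list per call unless one is passed in
        if bucket is None:
            bucket = []
        bucket.append(item)
        return bucket


    print(bad_append(1))   # [1]
    print(bad_append(2))   # [1, 2]  <- state carried over from the first call
    print(good_append(1))  # [1]
    print(good_append(2))  # [2]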
@@ -208,7 +208,9 @@ class TrainerTTS:
         return speaker_manager

     @staticmethod
-    def get_scheduler(config: Coqpit, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+    def get_scheduler(
+        config: Coqpit, optimizer: torch.optim.Optimizer
+    ) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
         lr_scheduler = config.lr_scheduler
         lr_scheduler_params = config.lr_scheduler_params
         if lr_scheduler is None:
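get_scheduler is only reflowed here: the return annotation names torch.optim.lr_scheduler._LRScheduler, an underscore-prefixed (protected) base class, so pylint's protected-access warning has to be silenced on the line that carries the annotation, and splitting the signature leaves room for the comment without overrunning the line-length limit. The hunk does not show the method body; the sketch below assumes a by-name lookup on torch.optim.lr_scheduler, which is a guess at the behaviour rather than the repository's exact code.

    import torch
    from torch import nn


    def get_scheduler(name, params, optimizer):
        # resolve a scheduler class by name; None means no scheduler is configured
        if name is None:
            return None
        scheduler_class = getattr(torch.optim.lr_scheduler, name)
        return scheduler_class(optimizer, **params)


    optimizer = torch.optim.Adam(nn.Linear(4, 4).parameters(), lr=1e-3)
    scheduler = get_scheduler("StepLR", {"step_size": 10, "gamma": 0.5}, optimizer)
    print(type(scheduler).__name__)  # StepLR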
@@ -275,7 +275,7 @@ class AlignTTS(nn.Module):
             g: [B, C]
         """
         g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
-        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
+        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)  # pylint: disable=not-callable
         # pad input to prevent dropping the last word
         # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
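The only change in this AlignTTS hunk is an inline disable: pylint reports torch.tensor as not-callable, a long-standing false positive with PyTorch's extension modules, so the call is kept and the check is silenced. The line builds a one-element length tensor from the time dimension of the input, as this short example shows:

    import torch

    x = torch.randint(0, 10, (1, 37))  # [B, T] batch with a single sequence
    x_lengths = torch.tensor(x.shape[1:2]).to(x.device)  # pylint: disable=not-callable
    print(x_lengths)  # tensor([37])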
@@ -314,7 +314,7 @@ class AlignTTS(nn.Module):
             loss_dict["align_error"] = align_error
         return outputs, loss_dict

-    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
         model_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         mel_input = batch["mel_input"]
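train_log never touches self, which is what pylint's no-self-use flags. The disable keeps it an ordinary instance method, so the trainer can call it uniformly on any model and subclasses can override it with an implementation that does use self; converting it to a @staticmethod would be the alternative. The same one-line disable is applied to GlowTTS.train_log and SpeedySpeech.train_log in later hunks. A generic illustration of the two options:

    class Model:
        def train_log(self, value):  # pylint: disable=no-self-use
            # kept as an instance method so overrides may use self later
            print(f"logged {value}")

        @staticmethod
        def train_log_static(value):
            # alternative: drop self entirely
            print(f"logged {value}")


    m = Model()
    m.train_log(1)
    m.train_log_static(2)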
@@ -143,7 +143,9 @@ class GlowTTS(nn.Module):
         o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
         return y_mean, y_log_scale, o_attn_dur

-    def forward(self, x, x_lengths, y, y_lengths=None, cond_input={"x_vectors": None}):
+    def forward(
+        self, x, x_lengths, y, y_lengths=None, cond_input={"x_vectors": None}
+    ):  # pylint: disable=dangerous-default-value
         """
         Shapes:
             x: [B, T]
@@ -344,7 +346,7 @@ class GlowTTS(nn.Module):
             loss_dict["align_error"] = align_error
         return outputs, loss_dict

-    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
         model_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         mel_input = batch["mel_input"]
@@ -183,7 +183,7 @@ class SpeedySpeech(nn.Module):
             g: [B, C]
         """
         g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
-        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
+        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)  # pylint: disable=not-callable
         # input sequence should be greated than the max convolution size
         inference_padding = 5
         if x.shape[1] < 13:
@@ -226,7 +226,7 @@ class SpeedySpeech(nn.Module):
             loss_dict["align_error"] = align_error
         return outputs, loss_dict

-    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
         model_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         mel_input = batch["mel_input"]
@@ -79,7 +79,7 @@ class Tacotron(TacotronAbstract):
         use_gst=False,
         gst=None,
         memory_size=5,
-        gradual_training=[],
+        gradual_training=None,
     ):
         super().__init__(
             num_chars,
@@ -1,5 +1,4 @@
 # coding: utf-8
-import numpy as np
 import torch
 from torch import nn

@@ -77,7 +76,7 @@ class Tacotron2(TacotronAbstract):
         speaker_embedding_dim=None,
         use_gst=False,
         gst=None,
-        gradual_training=[],
+        gradual_training=None,
     ):
         super().__init__(
             num_chars,
@@ -1,5 +1,4 @@
 import copy
-import logging
 from abc import ABC, abstractmethod

 import torch
@@ -37,7 +36,7 @@ class TacotronAbstract(ABC, nn.Module):
         speaker_embedding_dim=None,
         use_gst=False,
         gst=None,
-        gradual_training=[],
+        gradual_training=None,
     ):
         """Abstract Tacotron class"""
         super().__init__()
@@ -239,4 +238,4 @@ class TacotronAbstract(ABC, nn.Module):
             trainer.model.decoder_backward.set_r(r)
         trainer.train_loader = trainer.setup_train_dataloader(self.ap, self.model.decoder.r, verbose=True)
         trainer.eval_loader = trainer.setup_eval_dataloder(self.ap, self.model.decoder.r)
-        logging.info(f"\n > Number of output frames: {self.decoder.r}")
+        print(f"\n > Number of output frames: {self.decoder.r}")
@@ -1,5 +1,4 @@
 import json
-import os
 import random
 from typing import Any, List, Union

@@ -11,79 +10,6 @@ from TTS.speaker_encoder.utils.generic_utils import setup_model
 from TTS.utils.audio import AudioProcessor


-def make_speakers_json_path(out_path):
-    """Returns conventional speakers.json location."""
-    return os.path.join(out_path, "speakers.json")
-
-
-def load_speaker_mapping(out_path):
-    """Loads speaker mapping if already present."""
-    if os.path.splitext(out_path)[1] == ".json":
-        json_file = out_path
-    else:
-        json_file = make_speakers_json_path(out_path)
-    with open(json_file) as f:
-        return json.load(f)
-
-
-def save_speaker_mapping(out_path, speaker_mapping):
-    """Saves speaker mapping if not yet present."""
-    if out_path is not None:
-        speakers_json_path = make_speakers_json_path(out_path)
-        with open(speakers_json_path, "w") as f:
-            json.dump(speaker_mapping, f, indent=4)
-
-
-def parse_speakers(c, args, meta_data_train, OUT_PATH):
-    """Returns number of speakers, speaker embedding shape and speaker mapping"""
-    if c.use_speaker_embedding:
-        speakers = get_speakers(meta_data_train)
-        if args.restore_path:
-            if c.use_external_speaker_embedding_file:  # if restore checkpoint and use External Embedding file
-                prev_out_path = os.path.dirname(args.restore_path)
-                speaker_mapping = load_speaker_mapping(prev_out_path)
-                if not speaker_mapping:
-                    print(
-                        "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
-                    )
-                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
-                    if not speaker_mapping:
-                        raise RuntimeError(
-                            "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file"
-                        )
-                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
-            elif (
-                not c.use_external_speaker_embedding_file
-            ):  # if restore checkpoint and don't use External Embedding file
-                prev_out_path = os.path.dirname(args.restore_path)
-                speaker_mapping = load_speaker_mapping(prev_out_path)
-                speaker_embedding_dim = None
-                assert all(speaker in speaker_mapping for speaker in speakers), (
-                    "As of now you, you cannot " "introduce new speakers to " "a previously trained model."
-                )
-        elif (
-            c.use_external_speaker_embedding_file and c.external_speaker_embedding_file
-        ):  # if start new train using External Embedding file
-            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
-            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
-        elif (
-            c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file
-        ):  # if start new train using External Embedding file and don't pass external embedding file
-            raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
-        else:  # if start new train and don't use External Embedding file
-            speaker_mapping = {name: i for i, name in enumerate(speakers)}
-            speaker_embedding_dim = None
-        save_speaker_mapping(OUT_PATH, speaker_mapping)
-        num_speakers = len(speaker_mapping)
-        print(" > Training with {} speakers: {}".format(len(speakers), ", ".join(speakers)))
-    else:
-        num_speakers = 0
-        speaker_embedding_dim = None
-        speaker_mapping = None
-
-    return num_speakers, speaker_embedding_dim, speaker_mapping
-
-
 class SpeakerManager:
     """It manages the multi-speaker setup for 🐸TTS models. It loads the speaker files and parses the information
     in a way that you can query. There are 3 different scenarios considered.
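The deleted module-level helpers (make_speakers_json_path, load_speaker_mapping, save_speaker_mapping, parse_speakers) handled the speakers.json bookkeeping that the SpeakerManager class now owns. As the removed code shows, two file shapes were involved: a plain name-to-index mapping written for a fresh multi-speaker run, and an external speaker-embedding (x-vector) file whose entries carry an "embedding" vector. The "name" field and the concrete values below are illustrative, not taken from the repository. A sketch of the two shapes and the dimensionality check the old code performed:

    import json

    # fresh training run without an external embedding file:
    # speaker_mapping = {name: i for i, name in enumerate(speakers)}
    speaker_ids = {"ljspeech": 0, "vctk_p225": 1}

    # external speaker embedding file, one entry per audio clip:
    speaker_embeddings = {
        "p225_001.wav": {"name": "vctk_p225", "embedding": [0.01, -0.32, 0.77]},
    }

    # speakers.json as written by the old save_speaker_mapping()
    with open("/tmp/speakers.json", "w") as f:
        json.dump(speaker_ids, f, indent=4)

    # embedding dimensionality, as computed by the old parse_speakers()
    embedding_dim = len(speaker_embeddings[list(speaker_embeddings.keys())[0]]["embedding"])
    print(embedding_dim)  # 3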
@@ -230,15 +230,16 @@ def synthesis(
         outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, x_vector=x_vector)
         model_outputs = outputs["model_outputs"]
         model_outputs = model_outputs[0].data.cpu().numpy()
+        alignments = outputs["alignments"]
     elif backend == "tf":
         decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
             model, text_inputs, CONFIG, speaker_id, style_mel
         )
-        model_outputs, decoder_output, alignment, stop_tokens = parse_outputs_tf(
+        model_outputs, decoder_output, alignments, stop_tokens = parse_outputs_tf(
             postnet_output, decoder_output, alignments, stop_tokens
         )
     elif backend == "tflite":
-        decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
+        decoder_output, postnet_output, alignments, stop_tokens = run_model_tflite(
             model, text_inputs, CONFIG, speaker_id, style_mel
         )
         model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
@@ -252,7 +253,7 @@ def synthesis(
         wav = trim_silence(wav, ap)
     return_dict = {
         "wav": wav,
-        "alignments": outputs["alignments"],
+        "alignments": alignments,
         "model_outputs": model_outputs,
         "text_inputs": text_inputs,
     }
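The synthesis() hunks fix an inconsistency rather than a lint warning: the tf and tflite branches used to bind alignment (singular) while the return dict read outputs["alignments"], and outputs only exists on the torch path. Normalising every backend to a local alignments variable makes the return dict valid for all three. A minimal illustration of the bug class, with generic names rather than the TTS API:

    def synthesize(backend):
        if backend == "torch":
            alignments = "torch-align"
        elif backend == "tf":
            alignment = "tf-align"  # note the mismatched name
        return {"alignments": alignments}


    print(synthesize("torch"))  # works
    try:
        synthesize("tf")        # fails: 'alignments' was never bound on this path
    except UnboundLocalError as err:
        print(err)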