mirror of https://github.com/coqui-ai/TTS.git
Update pylint to 2.10.2 and fix lint issues
parent ccef20bff9
commit 18da8f5dbd
@@ -1,6 +1,6 @@
 import os
 
-with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f:
+with open(os.path.join(os.path.dirname(__file__), "VERSION"), 'r', encoding='utf-8') as f:
     version = f.read().strip()
 
 __version__ = version
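Most hunks in this commit apply the same fix: pylint 2.10 adds the `unspecified-encoding` check (W1514), which flags any `open()` call in text mode that relies on the platform's default encoding. A minimal sketch of the pattern, with an illustrative file name:

    # Reading text without an explicit encoding depends on the locale
    # (e.g. cp1252 on Windows); pinning utf-8 makes reads deterministic.
    with open("VERSION", "r", encoding="utf-8") as f:  # illustrative path
        version = f.read().strip()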
@@ -158,7 +158,7 @@ Example run:
 # output metafile
 metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
 
-with open(metafile, "w") as f:
+with open(metafile, "w", encoding="utf-8") as f:
     for p in file_paths:
         f.write(f"{p[0]}|{p[1]}\n")
 print(f" >> Metafile created: {metafile}")
@@ -215,7 +215,7 @@ def extract_spectrograms(
             wav = ap.inv_melspectrogram(mel)
             ap.save_wav(wav, wav_gl_path)
 
-    with open(os.path.join(output_path, metada_name), "w") as f:
+    with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f:
         for data in export_metadata:
             f.write(f"{data[0]}|{data[1]+'.npy'}\n")
 
TTS/model.py
@@ -23,35 +23,31 @@ class BaseModel(nn.Module, ABC):
     """
 
     @abstractmethod
-    def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict:
+    def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
         """Forward pass for the model mainly used in training.
 
-        You can be flexible here and use different number of arguments and argument names since it is mostly used by
-        `train_step()` in training whitout exposing it to the out of the class.
+        You can be flexible here and use different number of arguments and argument names since it is intended to be
+        used by `train_step()` without exposing it out of the model.
 
         Args:
-            text (torch.Tensor): Input text character sequence ids.
+            input (torch.Tensor): Input tensor.
             aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs.
-                for the model.
 
         Returns:
-            Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
+            Dict: Model outputs. Main model output must be named as "model_outputs".
         """
         outputs_dict = {"model_outputs": None}
         ...
         return outputs_dict
 
     @abstractmethod
-    def inference(self, text: torch.Tensor, aux_input={}) -> Dict:
+    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
         """Forward pass for inference.
 
-        After the model is trained this is the only function that connects the model the out world.
-
-        This function must only take a `text` input and a dictionary that has all the other model specific inputs.
         We don't use `*kwargs` since it is problematic with the TorchScript API.
 
         Args:
-            text (torch.Tensor): [description]
+            input (torch.Tensor): [description]
             aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.
 
         Returns:
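For context, the contract these docstrings describe can be satisfied by a subclass along these lines — a hedged sketch, not code from the repository; the class name and layer are made up, and the real `BaseModel` has further abstract methods omitted here:

    import torch
    from torch import nn

    class MyTTSModel(BaseModel):  # hypothetical subclass of TTS/model.py's BaseModel
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(80, 80)  # placeholder layer

        def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs):
            # training entry point, called by `train_step()`
            return {"model_outputs": self.net(input)}

        def inference(self, input: torch.Tensor, aux_input={}):
            # no **kwargs here, to stay TorchScript-friendly
            return {"model_outputs": self.net(input)}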
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 
 
 # adapted from https://github.com/cvqluu/GE2E-Loss
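This and the similar import hunks below address pylint's `consider-using-from-import` (R0402): `import torch.nn as nn` is just a renamed submodule import, so the from-form is preferred. Both bind the same module object, as this small check shows:

    # The two forms are equivalent bindings; pylint 2.10 prefers the latter.
    import torch.nn as nn_alias
    from torch import nn

    assert nn_alias is nn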
@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-import torch.nn as nn
+from torch import nn
 
 from TTS.utils.io import load_fsspec
 
@@ -94,7 +94,8 @@ def download_and_extract(directory, subset, urls):
     extract_path = zip_filepath.strip(".zip")
 
     # check zip file md5sum
-    md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest()
+    with open(zip_filepath, "rb") as f_zip:
+        md5 = hashlib.md5(f_zip.read()).hexdigest()
     if md5 != MD5SUM[subset]:
         raise ValueError("md5sum of %s mismatch" % zip_filepath)
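The original line here also trips `consider-using-with` (R1732): the handle returned by the bare `open()` is never closed explicitly. A sketch of the fixed pattern as a standalone helper (the function name is illustrative, not from the repo):

    import hashlib

    def file_md5(path):
        # the context manager guarantees the handle is closed,
        # even if read() raises
        with open(path, "rb") as f:
            return hashlib.md5(f.read()).hexdigest()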
@@ -631,13 +631,13 @@ class Trainer:
         outputs = outputs_per_optimizer
 
         # update avg runtime stats
-        keep_avg_update = dict()
+        keep_avg_update = {}
         keep_avg_update["avg_loader_time"] = loader_time
         keep_avg_update["avg_step_time"] = step_time
         self.keep_avg_train.update_values(keep_avg_update)
 
         # update avg loss stats
-        update_eval_values = dict()
+        update_eval_values = {}
         for key, value in loss_dict.items():
             update_eval_values["avg_" + key] = value
         self.keep_avg_train.update_values(update_eval_values)
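These trainer hunks apply pylint 2.10's `use-dict-literal` (R1735); the same release adds `use-list-literal` (R1734), which the MelGAN discriminator hunk further down fixes the same way. The literal forms skip the global name lookup and call that `dict()`/`list()` incur:

    # equivalent results; the literals avoid a name lookup plus a call
    keep_avg_update = {}   # preferred over dict()
    scores = []            # preferred over list()
    keep_avg_update["avg_loader_time"] = 0.42  # illustrative value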
@@ -797,7 +797,7 @@ class Trainer:
         loss_dict = self._detach_loss_dict(loss_dict)
 
         # update avg stats
-        update_eval_values = dict()
+        update_eval_values = {}
         for key, value in loss_dict.items():
             update_eval_values["avg_" + key] = value
         self.keep_avg_eval.update_values(update_eval_values)
@@ -977,12 +977,13 @@ class Trainer:
     def __init__(self, print_to_terminal=True):
         self.print_to_terminal = print_to_terminal
         self.terminal = sys.stdout
-        self.log = open(log_file, "a")
+        self.log_file = log_file
 
     def write(self, message):
         if self.print_to_terminal:
             self.terminal.write(message)
-        self.log.write(message)
+        with open(self.log_file, "a", encoding="utf-8") as f:
+            f.write(message)
 
     def flush(self):
         # this flush method is needed for python 3 compatibility.
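This logger change trades a file handle held open for the process lifetime for an open-per-write pattern. A self-contained sketch of the resulting shape — the class name is illustrative, and the module-level `log_file` of the surrounding code is simplified to a constructor argument:

    import sys

    class ConsoleAndFileLogger:
        def __init__(self, log_file, print_to_terminal=True):
            self.print_to_terminal = print_to_terminal
            self.terminal = sys.stdout
            self.log_file = log_file  # keep the path, not an open handle

        def write(self, message):
            if self.print_to_terminal:
                self.terminal.write(message)
            # reopen per write: slightly slower, but never leaks the handle
            with open(self.log_file, "a", encoding="utf-8") as f:
                f.write(message)

        def flush(self):
            pass  # needed so the object quacks like a stream under Python 3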
@@ -66,7 +66,7 @@ def load_meta_data(datasets, eval_split=True):
 
 def load_attention_mask_meta_data(metafile_path):
     """Load meta data file created by compute_attention_masks.py"""
-    with open(metafile_path, "r") as f:
+    with open(metafile_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
 
     meta_data = []
@@ -19,7 +19,7 @@ def tweb(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "tweb"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("\t")
             wav_file = os.path.join(root_path, cols[0] + ".wav")
@@ -33,7 +33,7 @@ def mozilla(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "mozilla"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("|")
             wav_file = cols[1].strip()
@@ -77,7 +77,7 @@ def mailabs(root_path, meta_files=None):
                 continue
         speaker_name = speaker_name_match.group("speaker_name")
         print(" | > {}".format(csv_file))
-        with open(txt_file, "r") as ttf:
+        with open(txt_file, "r", encoding="utf-8") as ttf:
             for line in ttf:
                 cols = line.split("|")
                 if meta_files is None:
@@ -102,7 +102,7 @@ def ljspeech(root_path, meta_file):
         for line in ttf:
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
-            text = cols[1]
+            text = cols[2]
             items.append([text, wav_file, speaker_name])
     return items
 
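All of these formatter hunks share one shape: a `formatter(root_path, meta_file)` function that parses a dataset's metadata file and returns `[text, wav_path, speaker_name]` rows. A hedged sketch of a custom formatter following the same convention — the function name and the pipe-separated layout are illustrative, not part of the repo:

    import os

    def my_dataset(root_path, meta_file):
        items = []
        speaker_name = "my_dataset"
        with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as ttf:
            for line in ttf:
                cols = line.strip().split("|")  # e.g. "clip001|hello world"
                wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
                items.append([cols[1], wav_file, speaker_name])
        return items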
@@ -116,7 +116,7 @@ def ljspeech_test(root_path, meta_file):
         for idx, line in enumerate(ttf):
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
-            text = cols[1]
+            text = cols[2]
             items.append([text, wav_file, f"ljspeech-{idx}"])
     return items
 
@@ -158,7 +158,7 @@ def css10(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "ljspeech"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("|")
             wav_file = os.path.join(root_path, cols[0])
@@ -172,7 +172,7 @@ def nancy(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "nancy"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             utt_id = line.split()[1]
             text = line[line.find('"') + 1 : line.rfind('"') - 1]
@@ -185,7 +185,7 @@ def common_voice(root_path, meta_file):
     """Normalize the common voice meta data file to TTS format."""
     txt_file = os.path.join(root_path, meta_file)
     items = []
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             if line.startswith("client_id"):
                 continue
@@ -208,7 +208,7 @@ def libri_tts(root_path, meta_files=None):
 
     for meta_file in meta_files:
         _meta_file = os.path.basename(meta_file).split(".")[0]
-        with open(meta_file, "r") as ttf:
+        with open(meta_file, "r", encoding="utf-8") as ttf:
             for line in ttf:
                 cols = line.split("\t")
                 file_name = cols[0]
@@ -245,7 +245,7 @@ def brspeech(root_path, meta_file):
     """BRSpeech 3.0 beta"""
     txt_file = os.path.join(root_path, meta_file)
     items = []
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             if line.startswith("wav_filename"):
                 continue
@@ -268,7 +268,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"):
         if isinstance(test_speakers, list):  # if it is a list, ignore these speaker ids
             if speaker_id in test_speakers:
                 continue
-        with open(meta_file) as file_text:
+        with open(meta_file, "r", encoding="utf-8") as file_text:
             text = file_text.readlines()[0]
         wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
         items.append([text, wav_file, "VCTK_" + speaker_id])
@@ -295,7 +295,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
 def mls(root_path, meta_files=None):
     """http://www.openslr.org/94/"""
     items = []
-    with open(os.path.join(root_path, meta_files), "r") as meta:
+    with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:
         for line in meta:
             file, text = line.split("\t")
             text = text[:-1]
@@ -329,7 +329,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
 
     # if the meta file does not exist, crawl recursively for 'wav' files
     if meta_file is not None:
-        with open(str(meta_file), "r") as f:
+        with open(str(meta_file), "r", encoding="utf-8") as f:
             return [x.strip().split("|") for x in f.readlines()]
 
     elif not cache_to.exists():
@@ -346,12 +346,12 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
                 text = None  # VoxCel does not provide transcriptions, and they are not needed for training the SE
                 meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n")
                 cnt += 1
-        with open(str(cache_to), "w") as f:
+        with open(str(cache_to), "w", encoding="utf-8") as f:
             f.write("".join(meta_data))
         if cnt < expected_count:
             raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}")
 
-    with open(str(cache_to), "r") as f:
+    with open(str(cache_to), "r", encoding="utf-8") as f:
         return [x.strip().split("|") for x in f.readlines()]
 
 
@@ -367,7 +367,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "baker"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             wav_name, text = line.rstrip("\n").split("|")
             wav_path = os.path.join(root_path, "clips_22", wav_name)
@@ -380,7 +380,7 @@ def kokoro(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "kokoro"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 
 
 class FFTransformer(nn.Module):
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 
 
 class GST(nn.Module):
@@ -388,8 +388,8 @@ class Decoder(nn.Module):
         decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
 
         # Pass through the decoder RNNs
-        for idx in range(len(self.decoder_rnns)):
-            self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx])
+        for idx, decoder_rnn in enumerate(self.decoder_rnns):
+            self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx])
             # Residual connection
             decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
         decoder_output = decoder_input
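This is pylint's `consider-using-enumerate` (C0200): indexing a sequence inside `range(len(...))` is better written with `enumerate`, which yields the index and element together. A standalone toy illustration:

    rnns = ["rnn_0", "rnn_1"]     # stand-ins for nn.Module cells
    hiddens = ["h_0", "h_1"]

    # before: for idx in range(len(rnns)): ... rnns[idx] ...
    for idx, rnn in enumerate(rnns):
        hiddens[idx] = f"{rnn}({hiddens[idx]})"  # toy hidden-state update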
@@ -2,8 +2,8 @@ from dataclasses import dataclass, field
 from typing import Dict, Tuple
 
 import torch
-import torch.nn as nn
 from coqpit import Coqpit
+from torch import nn
 
 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder
@@ -435,7 +435,7 @@ class Vits(BaseTTS):
                 attn_durations,
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
             )
-            loss_duration = loss_duration/ torch.sum(x_mask)
+            loss_duration = loss_duration / torch.sum(x_mask)
         else:
             attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
             log_durations = self.duration_predictor(
@@ -579,7 +579,7 @@ class Vits(BaseTTS):
                 scores_disc_fake=outputs["scores_disc_fake"],
                 feats_disc_fake=outputs["feats_disc_fake"],
                 feats_disc_real=outputs["feats_disc_real"],
-                loss_duration=outputs["loss_duration"]
+                loss_duration=outputs["loss_duration"],
             )
 
         elif optimizer_idx == 1:
@@ -18,46 +18,3 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
 
     # Initialize distributed communication
     dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name)
-
-
-def apply_gradient_allreduce(module):
-
-    # sync model parameters
-    for p in module.state_dict().values():
-        if not torch.is_tensor(p):
-            continue
-        dist.broadcast(p, 0)
-
-    def allreduce_params():
-        if module.needs_reduction:
-            module.needs_reduction = False
-            # bucketing params based on value types
-            buckets = {}
-            for param in module.parameters():
-                if param.requires_grad and param.grad is not None:
-                    tp = type(param.data)
-                    if tp not in buckets:
-                        buckets[tp] = []
-                    buckets[tp].append(param)
-            for tp in buckets:
-                bucket = buckets[tp]
-                grads = [param.grad.data for param in bucket]
-                coalesced = _flatten_dense_tensors(grads)
-                dist.all_reduce(coalesced, op=dist.reduce_op.SUM)
-                coalesced /= dist.get_world_size()
-                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                    buf.copy_(synced)
-
-    for param in list(module.parameters()):
-
-        def allreduce_hook(*_):
-            Variable._execution_engine.queue_callback(allreduce_params)  # pylint: disable=protected-access
-
-        if param.requires_grad:
-            param.register_hook(allreduce_hook)
-
-    def set_needs_reduction(self, *_):
-        self.needs_reduction = True
-
-    module.register_forward_hook(set_needs_reduction)
-    return module
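The deleted `apply_gradient_allreduce` hand-rolled gradient averaging with `_flatten_dense_tensors` and a backward hook. Nothing in this commit names a replacement, but the standard built-in covering the same ground is `torch.nn.parallel.DistributedDataParallel` — a minimal sketch, assuming the process group has been set up as in `init_distributed` above and that `model` and `rank` exist:

    from torch.nn.parallel import DistributedDataParallel as DDP

    # hypothetical model and device rank; DDP broadcasts parameters at
    # construction and all-reduces gradients automatically in backward()
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])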
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 from torch.nn.utils import weight_norm
 
 
@@ -1,8 +1,8 @@
 # adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
 from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, weight_norm
 
 from TTS.utils.io import load_fsspec
@@ -40,8 +40,8 @@ class MelganMultiscaleDiscriminator(nn.Module):
         )
 
     def forward(self, x):
-        scores = list()
-        feats = list()
+        scores = []
+        feats = []
         for disc in self.discriminators:
             score, feat = disc(x)
             scores.append(score)
@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 from torch.nn.utils import spectral_norm, weight_norm
 
 from TTS.utils.audio import TorchSTFT
@@ -5,9 +5,9 @@ from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from coqpit import Coqpit
+from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
@@ -43,7 +43,7 @@ def process_meta_data(path):
     meta_data = {}
 
     # load meta data
-    with open(path, "r") as f:
+    with open(path, "r", encoding="utf-8") as f:
         data = csv.reader(f, delimiter="|")
         for row in data:
             frames = int(row[2])
@@ -92,7 +92,7 @@ def save_training(file_path, meta_data):
         rows.append(d["row"] + "\n")
 
     random.shuffle(rows)
-    with open(file_path, "w+") as f:
+    with open(file_path, "w+", encoding="utf-8") as f:
         for row in rows:
             f.write(row)
 
@@ -156,7 +156,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
 
     phonemes = {}
 
-    with open(train_path, "r") as f:
+    with open(train_path, "r", encoding="utf-8") as f:
         data = csv.reader(f, delimiter="|")
         phonemes["None"] = 0
         for row in data:
@@ -174,9 +174,9 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
             phonemes["None"] += 1
 
     x, y = [], []
-    for key in phonemes:
-        x.append(key)
-        y.append(phonemes[key])
+    for k, v in phonemes.items():
+        x.append(k)
+        y.append(v)
 
     plt.figure()
     plt.rcParams["figure.figsize"] = (50, 20)
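Iterating a dict while indexing it by the loop key trips `consider-using-dict-items` (C0206); `.items()` yields key and value in one pass. A toy illustration with made-up counts:

    phonemes = {"AH": 3, "None": 1}  # illustrative data
    x, y = [], []
    for k, v in phonemes.items():    # no phonemes[k] lookup per iteration
        x.append(k)
        y.append(v)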
@@ -29,7 +29,7 @@ config = VitsConfig(
     run_name="vits_ljspeech",
     batch_size=48,
     eval_batch_size=16,
-    batch_group_size=0,
+    batch_group_size=5,
     num_loader_workers=4,
     num_eval_loader_workers=4,
     run_eval=True,
@@ -2,4 +2,4 @@ black
 coverage
 isort
 nose
-pylint==2.8.3
+pylint==2.10.2