Update pylint 2.10.2 and fix lint issues

pull/725/head
Eren Gölge 2021-08-30 08:08:45 +00:00
parent ccef20bff9
commit 18da8f5dbd
24 changed files with 60 additions and 105 deletions

View File

@@ -1,6 +1,6 @@
 import os

-with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f:
+with open(os.path.join(os.path.dirname(__file__), "VERSION"), 'r', encoding='utf-8') as f:
     version = f.read().strip()

 __version__ = version
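This first hunk shows the pattern that drives most of the commit: pylint 2.10 introduced W1514 (unspecified-encoding), which flags every open() call that relies on the locale-dependent default encoding. A minimal, self-contained sketch of why the explicit encoding matters (file name is illustrative):

import locale

# This is what open() uses when no encoding is given; it varies across
# platforms (e.g. cp1252 on many Windows setups).
print(locale.getpreferredencoding(False))

# Write non-ASCII text with an explicit encoding ...
with open("demo.txt", "w", encoding="utf-8") as f:
    f.write("Gölge\n")

# ... then read it back. `open("demo.txt")` alone is flagged by W1514
# and can raise UnicodeDecodeError on a non-UTF-8 locale; this form is
# deterministic everywhere.
with open("demo.txt", "r", encoding="utf-8") as f:
    print(f.read().strip())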

View File

@@ -158,7 +158,7 @@ Example run:
     # output metafile
     metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
-    with open(metafile, "w") as f:
+    with open(metafile, "w", encoding="utf-8") as f:
         for p in file_paths:
             f.write(f"{p[0]}|{p[1]}\n")
     print(f" >> Metafile created: {metafile}")

View File

@@ -215,7 +215,7 @@ def extract_spectrograms(
             wav = ap.inv_melspectrogram(mel)
             ap.save_wav(wav, wav_gl_path)
-    with open(os.path.join(output_path, metada_name), "w") as f:
+    with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f:
         for data in export_metadata:
             f.write(f"{data[0]}|{data[1]+'.npy'}\n")

View File

@@ -23,35 +23,31 @@ class BaseModel(nn.Module, ABC):
     """

     @abstractmethod
-    def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict:
+    def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict:
         """Forward pass for the model mainly used in training.

-        You can be flexible here and use a different number of arguments and argument names since it is mostly used by
-        `train_step()` in training without exposing it out of the class.
+        You can be flexible here and use a different number of arguments and argument names since it is intended to be
+        used by `train_step()` without exposing it out of the model.

         Args:
-            text (torch.Tensor): Input text character sequence ids.
+            input (torch.Tensor): Input tensor.
             aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs
                 for the model.

         Returns:
-            Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
+            Dict: Model outputs. Main model output must be named as "model_outputs".
         """
         outputs_dict = {"model_outputs": None}
         ...
         return outputs_dict

     @abstractmethod
-    def inference(self, text: torch.Tensor, aux_input={}) -> Dict:
+    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
         """Forward pass for inference.

         After the model is trained, this is the only function that connects the model to the outside world.
         This function must only take a `text` input and a dictionary that has all the other model-specific inputs.
         We don't use `*kwargs` since it is problematic with the TorchScript API.

         Args:
-            text (torch.Tensor): [description]
+            input (torch.Tensor): [description]
             aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.

         Returns:

View File

@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn

 # adapted from https://github.com/cvqluu/GE2E-Loss
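The import shuffle here (repeated in several files below) addresses R0402 (consider-using-from-import), added around pylint 2.9: `import torch.nn as nn` is flagged in favor of `from torch import nn`. Both forms bind the same module object, as this short sketch shows:

# R0402 flags the aliased import; the from-import is the preferred form.
import torch.nn as nn_alias
from torch import nn

assert nn is nn_alias  # identical module object, only the syntax differs
print(nn.Linear(4, 2))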

View File

@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-import torch.nn as nn
+from torch import nn

 from TTS.utils.io import load_fsspec

View File

@@ -94,7 +94,8 @@ def download_and_extract(directory, subset, urls):
     extract_path = zip_filepath.strip(".zip")
     # check zip file md5sum
-    md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest()
+    with open(zip_filepath, "rb") as f_zip:
+        md5 = hashlib.md5(f_zip.read()).hexdigest()
     if md5 != MD5SUM[subset]:
         raise ValueError("md5sum of %s mismatch" % zip_filepath)
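The md5 change fixes R1732 (consider-using-with), introduced in pylint 2.8: the one-liner never closed its file handle, leaving that to the garbage collector. A self-contained sketch of the before/after:

import hashlib

# Create a small file so the sketch runs on its own.
with open("archive.zip", "wb") as f:
    f.write(b"demo bytes")

# Old form, flagged by R1732: the handle returned by open() is closed
# only whenever the garbage collector gets to it.
# md5 = hashlib.md5(open("archive.zip", "rb").read()).hexdigest()

# New form: the context manager closes the handle even if read() raises.
with open("archive.zip", "rb") as f_zip:
    md5 = hashlib.md5(f_zip.read()).hexdigest()
print(md5)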

View File

@@ -631,13 +631,13 @@ class Trainer:
             outputs = outputs_per_optimizer

         # update avg runtime stats
-        keep_avg_update = dict()
+        keep_avg_update = {}
         keep_avg_update["avg_loader_time"] = loader_time
         keep_avg_update["avg_step_time"] = step_time
         self.keep_avg_train.update_values(keep_avg_update)

         # update avg loss stats
-        update_eval_values = dict()
+        update_eval_values = {}
         for key, value in loss_dict.items():
             update_eval_values["avg_" + key] = value
         self.keep_avg_train.update_values(update_eval_values)
@@ -797,7 +797,7 @@ class Trainer:
         loss_dict = self._detach_loss_dict(loss_dict)

         # update avg stats
-        update_eval_values = dict()
+        update_eval_values = {}
         for key, value in loss_dict.items():
             update_eval_values["avg_" + key] = value
         self.keep_avg_eval.update_values(update_eval_values)
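The dict() → {} edits satisfy R1735 (use-dict-literal), one of the checks added in pylint 2.10; its companion R1734 (use-list-literal) drives the list() → [] changes further down. The literal avoids a global name lookup and a call, which is also measurably faster:

import timeit

# dict() resolves a global name and makes a call; {} compiles to a
# single BUILD_MAP opcode.
print(timeit.timeit("dict()", number=1_000_000))
print(timeit.timeit("{}", number=1_000_000))

keep_avg_update = {}  # the form R1735 prefers
keep_avg_update["avg_loader_time"] = 0.42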
@@ -977,12 +977,13 @@ class Trainer:
     def __init__(self, print_to_terminal=True):
         self.print_to_terminal = print_to_terminal
         self.terminal = sys.stdout
-        self.log = open(log_file, "a")
+        self.log_file = log_file

     def write(self, message):
         if self.print_to_terminal:
             self.terminal.write(message)
-        self.log.write(message)
+        with open(self.log_file, "a", encoding="utf-8") as f:
+            f.write(message)

     def flush(self):
         # this flush method is needed for python 3 compatibility.
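The logger rewrite satisfies R1732 by storing a path instead of a long-lived open handle and reopening in append mode per message: crash-safe, at the cost of an open/close pair per write. A runnable sketch of the resulting class, with log_file passed explicitly (the original takes it from the enclosing scope):

import sys

class Logger:
    """Tee writes to the terminal and a log file (sketch; the log_file
    parameter is an assumption, the original reads it from outer scope)."""

    def __init__(self, log_file, print_to_terminal=True):
        self.print_to_terminal = print_to_terminal
        self.terminal = sys.stdout
        self.log_file = log_file

    def write(self, message):
        if self.print_to_terminal:
            self.terminal.write(message)
        # Reopen per message: satisfies R1732 and never loses buffered
        # output on a crash, at the cost of one open/close per write.
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(message)

    def flush(self):
        # needed so the object can stand in for sys.stdout
        self.terminal.flush()

sys.stdout = Logger("train.log")
print("logged to terminal and file")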

View File

@@ -66,7 +66,7 @@ def load_meta_data(datasets, eval_split=True):

 def load_attention_mask_meta_data(metafile_path):
     """Load meta data file created by compute_attention_masks.py"""
-    with open(metafile_path, "r") as f:
+    with open(metafile_path, "r", encoding="utf-8") as f:
         lines = f.readlines()

     meta_data = []

View File

@@ -19,7 +19,7 @@ def tweb(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "tweb"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("\t")
             wav_file = os.path.join(root_path, cols[0] + ".wav")
@@ -33,7 +33,7 @@ def mozilla(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "mozilla"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("|")
             wav_file = cols[1].strip()
@@ -77,7 +77,7 @@ def mailabs(root_path, meta_files=None):
             continue
         speaker_name = speaker_name_match.group("speaker_name")
         print(" | > {}".format(csv_file))
-        with open(txt_file, "r") as ttf:
+        with open(txt_file, "r", encoding="utf-8") as ttf:
             for line in ttf:
                 cols = line.split("|")
                 if meta_files is None:
@@ -102,7 +102,7 @@ def ljspeech(root_path, meta_file):
         for line in ttf:
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
-            text = cols[1]
+            text = cols[2]
             items.append([text, wav_file, speaker_name])
     return items
@@ -116,7 +116,7 @@ def ljspeech_test(root_path, meta_file):
         for idx, line in enumerate(ttf):
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
-            text = cols[1]
+            text = cols[2]
             items.append([text, wav_file, f"ljspeech-{idx}"])
     return items
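Not a lint fix: switching cols[1] to cols[2] makes the ljspeech formatters read the normalized transcript column of metadata.csv (numbers and abbreviations spelled out) instead of the raw one. A sketch with a hypothetical row, not a verbatim line from the corpus:

# Hypothetical metadata.csv row: <clip id>|<raw text>|<normalized text>
line = "LJ000-0000|Chapter 1 of 12|Chapter one of twelve"

cols = line.split("|")
print(cols[1])  # raw transcript: what the formatter read before
print(cols[2])  # normalized transcript: what it reads after this commit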
@@ -158,7 +158,7 @@ def css10(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "ljspeech"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("|")
             wav_file = os.path.join(root_path, cols[0])
@@ -172,7 +172,7 @@ def nancy(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "nancy"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             utt_id = line.split()[1]
             text = line[line.find('"') + 1 : line.rfind('"') - 1]
@@ -185,7 +185,7 @@ def common_voice(root_path, meta_file):
     """Normalize the common voice meta data file to TTS format."""
     txt_file = os.path.join(root_path, meta_file)
     items = []
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             if line.startswith("client_id"):
                 continue
@@ -208,7 +208,7 @@ def libri_tts(root_path, meta_files=None):
     for meta_file in meta_files:
         _meta_file = os.path.basename(meta_file).split(".")[0]
-        with open(meta_file, "r") as ttf:
+        with open(meta_file, "r", encoding="utf-8") as ttf:
             for line in ttf:
                 cols = line.split("\t")
                 file_name = cols[0]
@@ -245,7 +245,7 @@ def brspeech(root_path, meta_file):
     """BRSpeech 3.0 beta"""
     txt_file = os.path.join(root_path, meta_file)
     items = []
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             if line.startswith("wav_filename"):
                 continue
@@ -268,7 +268,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48"):
         if isinstance(test_speakers, list):  # if it is a list, ignore these speaker ids
             if speaker_id in test_speakers:
                 continue
-        with open(meta_file) as file_text:
+        with open(meta_file, "r", encoding="utf-8") as file_text:
             text = file_text.readlines()[0]
         wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
         items.append([text, wav_file, "VCTK_" + speaker_id])
@@ -295,7 +295,7 @@ def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):

 def mls(root_path, meta_files=None):
     """http://www.openslr.org/94/"""
     items = []
-    with open(os.path.join(root_path, meta_files), "r") as meta:
+    with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta:
         for line in meta:
             file, text = line.split("\t")
             text = text[:-1]
@@ -329,7 +329,7 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
     # if the meta file does not exist, crawl recursively for 'wav' files
     if meta_file is not None:
-        with open(str(meta_file), "r") as f:
+        with open(str(meta_file), "r", encoding="utf-8") as f:
             return [x.strip().split("|") for x in f.readlines()]
     elif not cache_to.exists():
@@ -346,12 +346,12 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
             text = None  # VoxCel does not provide transcriptions, and they are not needed for training the SE
             meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n")
             cnt += 1
-        with open(str(cache_to), "w") as f:
+        with open(str(cache_to), "w", encoding="utf-8") as f:
             f.write("".join(meta_data))
         if cnt < expected_count:
             raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}")
-    with open(str(cache_to), "r") as f:
+    with open(str(cache_to), "r", encoding="utf-8") as f:
         return [x.strip().split("|") for x in f.readlines()]
@@ -367,7 +367,7 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "baker"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             wav_name, text = line.rstrip("\n").split("|")
             wav_path = os.path.join(root_path, "clips_22", wav_name)
@@ -380,7 +380,7 @@ def kokoro(root_path, meta_file):
     txt_file = os.path.join(root_path, meta_file)
     items = []
     speaker_name = "kokoro"
-    with open(txt_file, "r") as ttf:
+    with open(txt_file, "r", encoding="utf-8") as ttf:
         for line in ttf:
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")

View File

@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn

 class FFTransformer(nn.Module):

View File

@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn

 class GST(nn.Module):

View File

@@ -388,8 +388,8 @@ class Decoder(nn.Module):
         decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1))

         # Pass through the decoder RNNs
-        for idx in range(len(self.decoder_rnns)):
-            self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx])
+        for idx, decoder_rnn in enumerate(self.decoder_rnns):
+            self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx])
             # Residual connection
             decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
         decoder_output = decoder_input
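This hunk addresses C0200 (consider-using-enumerate): enumerate() yields the index and the element together, so the loop body stops re-indexing the ModuleList on every pass. A small runnable sketch of the same shape, with toy sizes:

import torch
from torch import nn

decoder_rnns = nn.ModuleList(nn.GRUCell(8, 8) for _ in range(2))
x = torch.zeros(1, 8)
hiddens = [torch.zeros(1, 8) for _ in range(2)]

# Old form, flagged by C0200:
#   for idx in range(len(decoder_rnns)):
#       hiddens[idx] = decoder_rnns[idx](x, hiddens[idx])

# enumerate() drops the repeated decoder_rnns[idx] lookup while keeping
# the index for the hidden-state list.
for idx, decoder_rnn in enumerate(decoder_rnns):
    hiddens[idx] = decoder_rnn(x, hiddens[idx])
print(hiddens[0].shape)  # torch.Size([1, 8])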

View File

@@ -2,8 +2,8 @@ from dataclasses import dataclass, field
 from typing import Dict, Tuple

 import torch
-import torch.nn as nn
 from coqpit import Coqpit
+from torch import nn

 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder

View File

@@ -435,7 +435,7 @@ class Vits(BaseTTS):
                 attn_durations,
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
             )
-            loss_duration = loss_duration/ torch.sum(x_mask)
+            loss_duration = loss_duration / torch.sum(x_mask)
         else:
             attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
             log_durations = self.duration_predictor(
@@ -579,7 +579,7 @@ class Vits(BaseTTS):
                 scores_disc_fake=outputs["scores_disc_fake"],
                 feats_disc_fake=outputs["feats_disc_fake"],
                 feats_disc_real=outputs["feats_disc_real"],
-                loss_duration=outputs["loss_duration"]
+                loss_duration=outputs["loss_duration"],
             )
         elif optimizer_idx == 1:
elif optimizer_idx == 1:

View File

@@ -18,46 +18,3 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
     # Initialize distributed communication
     dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name)
-
-
-def apply_gradient_allreduce(module):
-    # sync model parameters
-    for p in module.state_dict().values():
-        if not torch.is_tensor(p):
-            continue
-        dist.broadcast(p, 0)
-
-    def allreduce_params():
-        if module.needs_reduction:
-            module.needs_reduction = False
-            # bucketing params based on value types
-            buckets = {}
-            for param in module.parameters():
-                if param.requires_grad and param.grad is not None:
-                    tp = type(param.data)
-                    if tp not in buckets:
-                        buckets[tp] = []
-                    buckets[tp].append(param)
-            for tp in buckets:
-                bucket = buckets[tp]
-                grads = [param.grad.data for param in bucket]
-                coalesced = _flatten_dense_tensors(grads)
-                dist.all_reduce(coalesced, op=dist.reduce_op.SUM)
-                coalesced /= dist.get_world_size()
-                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                    buf.copy_(synced)
-
-    for param in list(module.parameters()):
-
-        def allreduce_hook(*_):
-            Variable._execution_engine.queue_callback(allreduce_params)  # pylint: disable=protected-access
-
-        if param.requires_grad:
-            param.register_hook(allreduce_hook)
-
-    def set_needs_reduction(self, *_):
-        self.needs_reduction = True
-
-    module.register_forward_hook(set_needs_reduction)
-    return module
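The whole apply_gradient_allreduce() helper is dropped: it hand-rolled the parameter broadcast and gradient all-reduce that torch.nn.parallel.DistributedDataParallel performs natively. That replacement is an assumption on my part; the diff itself shows no call site. A minimal sketch of the DDP equivalent:

import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel

def wrap_model(model: nn.Module, rank: int) -> nn.Module:
    # DDP broadcasts parameters from rank 0 at construction time and
    # all-reduces gradients during backward() -- the same bookkeeping
    # apply_gradient_allreduce implemented by hand with Variable hooks.
    assert dist.is_initialized(), "call dist.init_process_group() first"
    return DistributedDataParallel(model, device_ids=[rank])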

View File

@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 from torch.nn.utils import weight_norm

View File

@@ -1,8 +1,8 @@
 # adapted from https://github.com/jik876/hifi-gan/blob/master/models.py
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
 from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, weight_norm

 from TTS.utils.io import load_fsspec

View File

@@ -40,8 +40,8 @@ class MelganMultiscaleDiscriminator(nn.Module):
         )

     def forward(self, x):
-        scores = list()
-        feats = list()
+        scores = []
+        feats = []
         for disc in self.discriminators:
             score, feat = disc(x)
             scores.append(score)

View File

@@ -1,6 +1,6 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
 from torch.nn.utils import spectral_norm, weight_norm

 from TTS.utils.audio import TorchSTFT

View File

@@ -5,9 +5,9 @@ from typing import Dict, List, Tuple

 import numpy as np
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from coqpit import Coqpit
+from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

View File

@@ -43,7 +43,7 @@ def process_meta_data(path):
     meta_data = {}

     # load meta data
-    with open(path, "r") as f:
+    with open(path, "r", encoding="utf-8") as f:
         data = csv.reader(f, delimiter="|")
         for row in data:
             frames = int(row[2])
@@ -92,7 +92,7 @@ def save_training(file_path, meta_data):
         rows.append(d["row"] + "\n")
     random.shuffle(rows)

-    with open(file_path, "w+") as f:
+    with open(file_path, "w+", encoding="utf-8") as f:
         for row in rows:
             f.write(row)
@@ -156,7 +156,7 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
     phonemes = {}

-    with open(train_path, "r") as f:
+    with open(train_path, "r", encoding="utf-8") as f:
         data = csv.reader(f, delimiter="|")
         phonemes["None"] = 0
         for row in data:
@@ -174,9 +174,9 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):
                 phonemes["None"] += 1

     x, y = [], []
-    for key in phonemes:
-        x.append(key)
-        y.append(phonemes[key])
+    for k, v in phonemes.items():
+        x.append(k)
+        y.append(v)

     plt.figure()
     plt.rcParams["figure.figsize"] = (50, 20)
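The last loop change fixes C0206 (consider-using-dict-items), added around pylint 2.9: iterating keys and then indexing pays a second hash lookup per key, while items() yields the pairs in one pass:

phonemes = {"AH": 3, "None": 1}
x, y = [], []

# Old form, flagged by C0206: one hash lookup to iterate, another to index.
# for key in phonemes:
#     x.append(key)
#     y.append(phonemes[key])

# items() walks the dict once, yielding (key, value) pairs directly.
for k, v in phonemes.items():
    x.append(k)
    y.append(v)
print(x, y)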

View File

@@ -29,7 +29,7 @@ config = VitsConfig(
     run_name="vits_ljspeech",
     batch_size=48,
     eval_batch_size=16,
-    batch_group_size=0,
+    batch_group_size=5,
     num_loader_workers=4,
     num_eval_loader_workers=4,
     run_eval=True,
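This recipe change is behavioral, not stylistic. In Coqui TTS, batch_group_size enables length-bucketed shuffling: as far as the loader suggests, samples sorted by length are shuffled within groups of batch_group_size * batch_size, so 5 restores shuffling variety that 0 (strict sorted order) lacks. An illustration of the bucketing idea only, not the actual loader code:

import random

def group_shuffle(sorted_items, batch_size=48, batch_group_size=5):
    # Shuffle within buckets of batch_group_size * batch_size items so
    # batches stay roughly length-homogeneous but vary between epochs.
    bucket = batch_group_size * batch_size
    out = []
    for i in range(0, len(sorted_items), bucket):
        chunk = sorted_items[i : i + bucket]
        random.shuffle(chunk)
        out.extend(chunk)
    return out

print(group_shuffle(list(range(10)), batch_size=2, batch_group_size=2))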

View File

@@ -2,4 +2,4 @@ black
 coverage
 isort
 nose
-pylint==2.8.3
+pylint==2.10.2
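With the pin bumped from 2.8.3 to 2.10.2, the new W1514/R1734/R1735 checks become active, which is what forced the edits above. A hypothetical way to reproduce the lint run locally (the module name and rcfile path are assumptions about the repo layout):

import subprocess
import sys

# Install the pinned linter, then lint the package with the repo config.
subprocess.run([sys.executable, "-m", "pip", "install", "pylint==2.10.2"], check=True)
subprocess.run([sys.executable, "-m", "pylint", "TTS", "--rcfile=.pylintrc"], check=True)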