a ton of linter updates

pull/367/head
Eren Gölge 2021-03-08 05:06:54 +01:00
parent 4422642ec0
commit 9a48ba3821
45 changed files with 244 additions and 241 deletions

View File

@@ -170,7 +170,7 @@ def main():
args.vocoder_name = model_item['default_vocoder'] if args.vocoder_name is None else args.vocoder_name
if args.vocoder_name is not None:
- vocoder_path, vocoder_config_path, vocoder_item = manager.download_model(args.vocoder_name)
+ vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
# CASE3: load custome models
if args.model_path is not None:
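
The change above renames an unused unpacking target to `_`, the conventional throwaway name, which silences pylint's unused-variable warning. A standalone sketch of the pattern (the `download_model` stub here is hypothetical, not the project's code):

def download_model(name):
    # hypothetical stand-in for ModelManager.download_model(), which
    # returns (model_path, config_path, model_item)
    return f"/models/{name}.pth", f"/models/{name}/config.json", {"default_vocoder": None}

# naming the third value would trip pylint's unused-variable check;
# `_` marks it as deliberately discarded
model_path, config_path, _ = download_model("tts_models/en/ljspeech/glow-tts")
print(model_path, config_path)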

View File

@@ -573,7 +573,7 @@ def main(args): # pylint: disable=redefined-outer-name
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer,
- global_step, epoch, c.r, OUT_PATH,
+ global_step, epoch, c.r, OUT_PATH, model_characters,
keep_all_best=keep_all_best, keep_after=keep_after)

View File

@@ -1,8 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
- import argparse
- import glob
import os
import sys
import time
@@ -535,7 +533,7 @@ def main(args): # pylint: disable=redefined-outer-name
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer,
- global_step, epoch, c.r, OUT_PATH,
+ global_step, epoch, c.r, OUT_PATH, model_characters,
keep_all_best=keep_all_best, keep_after=keep_after)

View File

@@ -648,12 +648,14 @@ def main(args): # pylint: disable=redefined-outer-name
epoch,
c.r,
OUT_PATH,
+ model_characters,
keep_all_best=keep_all_best,
keep_after=keep_after,
scaler=scaler.state_dict() if c.mixed_precision else None
)
if __name__ == '__main__':
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
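
Each training script above now threads `model_characters` into `save_best_model`, matching the updated signature shown further down in this commit (`def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs)`). A rough sketch of the idea; only the signature comes from this diff, the body and the `'characters'` field name are assumptions:

import torch

def save_best_model(target_loss, best_loss, model, optimizer, current_step,
                    epoch, r, output_folder, characters, **kwargs):
    # sketch: persist a checkpoint only when the eval loss improves
    if target_loss >= best_loss:
        return best_loss
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": current_step,
        "epoch": epoch,
        "r": r,
        "characters": characters,  # assumed field: lets inference rebuild the symbol table
    }
    state.update(kwargs)  # e.g. scaler=... for mixed-precision runs
    torch.save(state, f"{output_folder}/best_model.pth.tar")  # assumed file name
    return target_loss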

View File

@@ -50,7 +50,7 @@ def setup_loader(ap, is_val=False, verbose=False):
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
loader = DataLoader(dataset,
batch_size=1 if is_val else c.batch_size,
- shuffle=False if num_gpus > 1 else True,
+ shuffle=num_gpus == 0,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
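
One thing to flag in this hunk: the two shuffle expressions are not equivalent when exactly one GPU is in use. A quick standalone check (`num_gpus` is a plain int here):

for num_gpus in (0, 1, 2):
    old = False if num_gpus > 1 else True  # True for 0 or 1 GPU
    new = num_gpus == 0                    # True only for CPU-only runs
    print(num_gpus, old, new)
# prints:
# 0 True True
# 1 True False   <- shuffling is now disabled on single-GPU runs
# 2 False False

An equivalent simplification of the original would have been `shuffle=num_gpus <= 1`.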

View File

@@ -59,7 +59,7 @@ if args.list_models:
# set models by the released models
if args.model_name is not None:
tts_checkpoint_file, tts_config_file, tts_json_dict = manager.download_model(args.model_name)
- args.vocoder_name = tts_json_dict['default_vocoder'] if args.vocoder_name is None else args.vocoder_name
+ args.vocoder_name = tts_json_dict['default_vocoder'] if args.vocoder_name is None else args.vocoder_name
if args.vocoder_name is not None:
vocoder_checkpoint_file, vocoder_config_file, vocoder_json_dict = manager.download_model(args.vocoder_name)

View File

@@ -1,7 +1,7 @@
import collections
import os
import random
- from multiprocessing import Manager, Pool
+ from multiprocessing import Pool
import numpy as np
import torch

View File

@@ -3,7 +3,7 @@ from glob import glob
import re
import sys
from pathlib import Path
- from typing import List, Tuple
+ from typing import List
from tqdm import tqdm

View File

@@ -367,18 +367,18 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
beta (float, optional): [description]. Defaults to 0.9 from the paper.
"""
def __init__(
- self,
- query_dim,
- embedding_dim, # pylint: disable=unused-argument
- attention_dim,
- static_filter_dim,
- static_kernel_size,
- dynamic_filter_dim,
- dynamic_kernel_size,
- prior_filter_len=11,
- alpha=0.1,
- beta=0.9,
- ):
+ self,
+ query_dim,
+ embedding_dim, # pylint: disable=unused-argument
+ attention_dim,
+ static_filter_dim,
+ static_kernel_size,
+ dynamic_filter_dim,
+ dynamic_kernel_size,
+ prior_filter_len=11,
+ alpha=0.1,
+ beta=0.9,
+ ):
super().__init__()
self._mask_value = 1e-8
self.dynamic_filter_dim = dynamic_filter_dim
@@ -402,7 +402,7 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
self.v = nn.Linear(attention_dim, 1, bias=False)
prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1,
- alpha, beta)
+ alpha, beta)
self.register_buffer("prior", torch.FloatTensor(prior).flip(0))
# pylint: disable=unused-argument
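
For reference, the beta-binomial attention prior registered above can be reproduced standalone with the same `betabinom.pmf` call; the values below are the defaults from the signature (`prior_filter_len=11`, `alpha=0.1`, `beta=0.9`):

import torch
from scipy.stats import betabinom

prior_filter_len, alpha, beta = 11, 0.1, 0.9
prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta)
# flipped, as in register_buffer above, so the filter runs backwards in time
prior_filter = torch.FloatTensor(prior).flip(0)
print(prior_filter)  # 11 weights summing to 1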

View File

@@ -97,7 +97,7 @@ class ResidualConv1dBNBlock(nn.Module):
assert len(dilations) == num_res_blocks
self.res_blocks = nn.ModuleList()
for idx, dilation in enumerate(dilations):
- block = Conv1dBNBlock(in_channels if idx==0 else hidden_channels,
+ block = Conv1dBNBlock(in_channels if idx == 0 else hidden_channels,
out_channels if (idx + 1) == len(dilations) else hidden_channels,
hidden_channels,
kernel_size,

View File

@@ -98,11 +98,11 @@ class Encoder(nn.Module):
if encoder_type.lower() == "rel_pos_transformer":
if use_prenet:
self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
- hidden_channels,
- hidden_channels,
- kernel_size=5,
- num_layers=3,
- dropout_p=0.5)
+ hidden_channels,
+ hidden_channels,
+ kernel_size=5,
+ num_layers=3,
+ dropout_p=0.5)
self.encoder = RelativePositionTransformer(hidden_channels,
hidden_channels,
hidden_channels,
@@ -125,11 +125,11 @@ class Encoder(nn.Module):
elif encoder_type.lower() == 'time_depth_separable':
if use_prenet:
self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
- hidden_channels,
- hidden_channels,
- kernel_size=5,
- num_layers=3,
- dropout_p=0.5)
+ hidden_channels,
+ hidden_channels,
+ kernel_size=5,
+ num_layers=3,
+ dropout_p=0.5)
self.encoder = TimeDepthSeparableConvBlock(hidden_channels,
hidden_channels,
hidden_channels,

View File

@@ -366,8 +366,10 @@ class RelativePositionTransformer(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.ffn_layers.append(
- FeedForwardNetwork(hidden_channels,
- hidden_channels if (idx + 1) != self.num_layers else out_channels,
+ FeedForwardNetwork(
+ hidden_channels,
+ hidden_channels if
+ (idx + 1) != self.num_layers else out_channels,
hidden_channels_ffn,
kernel_size,
dropout_p=dropout_p))

View File

@@ -75,7 +75,7 @@ class ReferenceEncoder(nn.Module):
# x: 3D tensor [batch_size, post_conv_width,
# num_channels*post_conv_height]
self.recurrence.flatten_parameters()
- memory, out = self.recurrence(x)
+ _, out = self.recurrence(x)
# out: 3D tensor [seq_len==1, batch_size, encoding_size=128]
return out.squeeze(0)

View File

@@ -2,13 +2,12 @@ import math
import numpy as np
import torch
from torch import nn
- from inspect import signature
from torch.nn import functional
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.utils.ssim import ssim
- # pylint: disable=abstract-method Method
+ # pylint: disable=abstract-method
# relates https://github.com/pytorch/pytorch/issues/42305
class L1LossMasked(nn.Module):
def __init__(self, seq_len_norm):
@@ -165,7 +164,7 @@ class BCELossMasked(nn.Module):
target.requires_grad = False
if length is not None:
mask = sequence_mask(sequence_length=length,
- max_len=target.size(1)).float()
+ max_len=target.size(1)).float()
x = x * mask
target = target * mask
num_items = mask.sum()
@@ -310,10 +309,10 @@ class TacotronLoss(torch.nn.Module):
if self.postnet_alpha > 0:
if self.config.model in ["Tacotron", "TacotronGST"]:
postnet_loss = self.criterion(postnet_output, linear_input,
- output_lens)
+ output_lens)
else:
postnet_loss = self.criterion(postnet_output, mel_input,
- output_lens)
+ output_lens)
else:
if self.decoder_alpha > 0:
decoder_loss = self.criterion(decoder_output, mel_input)

View File

@@ -146,17 +146,17 @@ class Decoder(nn.Module):
# pylint: disable=dangerous-default-value
def __init__(
- self,
- out_channels,
- in_hidden_channels,
- decoder_type='residual_conv_bn',
- decoder_params={
- "kernel_size": 4,
- "dilations": 4 * [1, 2, 4, 8] + [1],
- "num_conv_blocks": 2,
- "num_res_blocks": 17
- },
- c_in_channels=0):
+ self,
+ out_channels,
+ in_hidden_channels,
+ decoder_type='residual_conv_bn',
+ decoder_params={
+ "kernel_size": 4,
+ "dilations": 4 * [1, 2, 4, 8] + [1],
+ "num_conv_blocks": 2,
+ "num_res_blocks": 17
+ },
+ c_in_channels=0):
super().__init__()
if decoder_type == 'transformer':

View File

@@ -73,13 +73,12 @@ class RelativePositionTransformerEncoder(nn.Module):
def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
self.prenet = ResidualConv1dBNBlock(in_channels,
- hidden_channels,
- hidden_channels,
- kernel_size=5,
- num_res_blocks=3,
- num_conv_blocks=1,
- dilations=[1, 1, 1]
- )
+ hidden_channels,
+ hidden_channels,
+ kernel_size=5,
+ num_res_blocks=3,
+ num_conv_blocks=1,
+ dilations=[1, 1, 1])
self.rel_pos_transformer = RelativePositionTransformer(
hidden_channels, out_channels, hidden_channels, **params)
@@ -104,9 +103,8 @@ class ResidualConv1dBNEncoder(nn.Module):
"""
def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
- self.prenet = nn.Sequential(
- nn.Conv1d(in_channels, hidden_channels, 1),
- nn.ReLU())
+ self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1),
+ nn.ReLU())
self.res_conv_block = ResidualConv1dBNBlock(hidden_channels,
hidden_channels,
hidden_channels, **params)
@@ -162,17 +160,17 @@ class Encoder(nn.Module):
}
"""
def __init__(
- self,
- in_hidden_channels,
- out_channels,
- encoder_type='residual_conv_bn',
- encoder_params={
- "kernel_size": 4,
- "dilations": 4 * [1, 2, 4] + [1],
- "num_conv_blocks": 2,
- "num_res_blocks": 13
- },
- c_in_channels=0):
+ self,
+ in_hidden_channels,
+ out_channels,
+ encoder_type='residual_conv_bn',
+ encoder_params={
+ "kernel_size": 4,
+ "dilations": 4 * [1, 2, 4] + [1],
+ "num_conv_blocks": 2,
+ "num_res_blocks": 13
+ },
+ c_in_channels=0):
super().__init__()
self.out_channels = out_channels
self.in_channels = in_hidden_channels
@@ -183,10 +181,9 @@ class Encoder(nn.Module):
# init encoder
if encoder_type.lower() == "transformer":
# text encoder
- self.encoder = RelativePositionTransformerEncoder(in_hidden_channels,
- out_channels,
- in_hidden_channels,
- encoder_params) # pylint: disable=unexpected-keyword-arg
+ self.encoder = RelativePositionTransformerEncoder(
+ in_hidden_channels, out_channels, in_hidden_channels,
+ encoder_params) # pylint: disable=unexpected-keyword-arg
elif encoder_type.lower() == 'residual_conv_bn':
self.encoder = ResidualConv1dBNEncoder(in_hidden_channels,
out_channels,

View File

@@ -33,32 +33,32 @@ class SpeedySpeech(nn.Module):
external_c (bool, optional): enable external speaker embeddings. Defaults to False.
c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0.
"""
- # pylint: disable=dangerous-default-value
+ # pylint: disable=dangerous-default-value
def __init__(
- self,
- num_chars,
- out_channels,
- hidden_channels,
- positional_encoding=True,
- length_scale=1,
- encoder_type='residual_conv_bn',
- encoder_params={
- "kernel_size": 4,
- "dilations": 4 * [1, 2, 4] + [1],
- "num_conv_blocks": 2,
- "num_res_blocks": 13
- },
- decoder_type='residual_conv_bn',
- decoder_params={
- "kernel_size": 4,
- "dilations": 4 * [1, 2, 4, 8] + [1],
- "num_conv_blocks": 2,
- "num_res_blocks": 17
- },
- num_speakers=0,
- external_c=False,
- c_in_channels=0):
+ self,
+ num_chars,
+ out_channels,
+ hidden_channels,
+ positional_encoding=True,
+ length_scale=1,
+ encoder_type='residual_conv_bn',
+ encoder_params={
+ "kernel_size": 4,
+ "dilations": 4 * [1, 2, 4] + [1],
+ "num_conv_blocks": 2,
+ "num_res_blocks": 13
+ },
+ decoder_type='residual_conv_bn',
+ decoder_params={
+ "kernel_size": 4,
+ "dilations": 4 * [1, 2, 4, 8] + [1],
+ "num_conv_blocks": 2,
+ "num_res_blocks": 17
+ },
+ num_speakers=0,
+ external_c=False,
+ c_in_channels=0):
super().__init__()
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
@@ -171,7 +171,7 @@ class SpeedySpeech(nn.Module):
"""
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
- o_de, attn= self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
+ o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
return o_de, o_dr_log.squeeze(1), attn
def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument

View File

@@ -10,7 +10,7 @@ import re
import itertools
- def _num2chinese(num :str, big=False, simp=True, o=False, twoalt=False) -> str:
+ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
"""Convert numerical arabic numbers (0->9) to chinese hanzi numbers ( -> 九)
Args:
@@ -32,7 +32,7 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
nd = str(num)
if abs(float(nd)) >= 1e48:
raise ValueError('number out of range')
- elif 'e' in nd:
+ if 'e' in nd:
raise ValueError('scientific notation is not supported')
c_symbol = '正负点' if simp else '正負點'
if o: # formal
@@ -69,7 +69,7 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
if int(unit) == 0: # 0000
intresult.append(c_basic[0])
continue
- elif nu > 0 and int(unit) == 2: # 0002
+ if nu > 0 and int(unit) == 2: # 0002
intresult.append(c_twoalt + c_unit2[nu - 1])
continue
ulist = []
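
Both `elif` → `if` edits in this file follow the same rule: when the preceding branch ends in `raise` or `continue`, control never falls through, so the `elif` is redundant and pylint (`no-else-raise`, `no-else-continue`) prefers a plain `if`. A minimal sketch:

def validate(nd: str) -> None:
    if abs(float(nd)) >= 1e48:
        raise ValueError('number out of range')
    # `elif` would add nothing here: the branch above never falls through
    if 'e' in nd:
        raise ValueError('scientific notation is not supported')

validate('12345')  # passes silently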

View File

@@ -135,7 +135,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
return model
def is_tacotron(c):
- return False if c['model'] in ['speedy_speech', 'glow_tts'] else True
+ return not c['model'] in ['speedy_speech', 'glow_tts']
def check_config_tts(c):
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech'], restricted=True, val_type=str)
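
The `is_tacotron` rewrite drops a redundant ternary (pylint's `simplifiable-if-expression`). All three forms below are equivalent; `not in` would be the most idiomatic:

c = {'model': 'tacotron2'}

before = False if c['model'] in ['speedy_speech', 'glow_tts'] else True
after = not c['model'] in ['speedy_speech', 'glow_tts']
idiomatic = c['model'] not in ['speedy_speech', 'glow_tts']
assert before == after == idiomatic  # True for any model name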

View File

@@ -7,7 +7,7 @@ from TTS.utils.io import RenamingUnpickler
- def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False):
+ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
"""Load ```TTS.tts.models``` checkpoints.
Args:
@@ -98,7 +98,7 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder,
def save_best_model(target_loss, best_loss, model, optimizer, current_step,
- epoch, r, output_folder, characters, **kwargs):
+ epoch, r, output_folder, characters, **kwargs):
"""Save model checkpoint, intended for saving the best model after each epoch.
It compares the current model loss with the best loss so far and saves the
model if the current loss is better.
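
The new `# pylint: disable=redefined-builtin` acknowledges that the `eval` keyword argument shadows Python's built-in `eval()` inside the function body; the public API is kept stable instead of renaming the flag. A minimal illustration of the shadowing:

def load(eval=False):  # pylint: disable=redefined-builtin
    # within this scope `eval` is the boolean flag, not the builtin function
    return "eval mode" if eval else "train mode"

print(load(eval=True))  # eval mode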

View File

@@ -63,8 +63,8 @@ def parse_speakers(c, args, meta_data_train, OUT_PATH):
speaker_embedding_dim = None
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
- print(" > Training with {} speakers: {}".format(len(speakers),
- ", ".join(speakers)))
+ print(" > Training with {} speakers: {}".format(
+ len(speakers), ", ".join(speakers)))
else:
num_speakers = 0
speaker_embedding_dim = None

View File

@@ -17,17 +17,22 @@ def create_window(window_size, channel):
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
return window
- def _ssim(img1, img2, window, window_size, channel, size_average = True):
- mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
- mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
+ def _ssim(img1, img2, window, window_size, channel, size_average=True):
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
- sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
- sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
- sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
+ sigma1_sq = F.conv2d(
+ img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(
+ img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+ sigma12 = F.conv2d(
+ img1 * img2, window, padding=window_size // 2,
+ groups=channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
@@ -39,7 +44,7 @@ def _ssim(img1, img2, window, window_size, channel, size_average = True):
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
- def __init__(self, window_size = 11, size_average = True):
+ def __init__(self, window_size=11, size_average=True):
super().__init__()
self.window_size = window_size
self.size_average = size_average
@@ -64,7 +69,8 @@ class SSIM(torch.nn.Module):
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
- def ssim(img1, img2, window_size = 11, size_average = True):
+ def ssim(img1, img2, window_size=11, size_average=True):
(_, channel, _, _) = img1.size()
window = create_window(window_size, channel)
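
Behaviour is unchanged by the keyword-spacing cleanup; for orientation, the public helper can be exercised like this (import path taken from the `from TTS.tts.utils.ssim import ssim` line in the losses file above):

import torch
from TTS.tts.utils.ssim import ssim

img1 = torch.rand(1, 1, 64, 64)           # [batch, channels, height, width]
print(ssim(img1, img1.clone()))           # 1.0 for identical images
print(ssim(img1, torch.rand_like(img1)))  # noticeably lower for unrelated noise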

View File

@@ -20,9 +20,13 @@ def text_to_seqvec(text, CONFIG):
add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False),
dtype=np.int32)
else:
- seq = np.asarray(
- text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
- add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False), dtype=np.int32)
+ seq = np.asarray(text_to_sequence(
+ text,
+ text_cleaner,
+ tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
+ add_blank=CONFIG['add_blank']
+ if 'add_blank' in CONFIG.keys() else False),
+ dtype=np.int32)
return seq
@@ -77,9 +81,9 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
if hasattr(model, 'module'):
# distributed model
- postnet_output, alignments= model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
+ postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
else:
- postnet_output, alignments= model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
+ postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
postnet_output = postnet_output.permute(0, 2, 1)
# these only belong to tacotron models.
decoder_output = None

View File

@@ -2,60 +2,60 @@ import re
# List of (regular expression, replacement) pairs for abbreviations in english:
abbreviations_en = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
- for x in [
- ('mrs', 'misess'),
- ('mr', 'mister'),
- ('dr', 'doctor'),
- ('st', 'saint'),
- ('co', 'company'),
- ('jr', 'junior'),
- ('maj', 'major'),
- ('gen', 'general'),
- ('drs', 'doctors'),
- ('rev', 'reverend'),
- ('lt', 'lieutenant'),
- ('hon', 'honorable'),
- ('sgt', 'sergeant'),
- ('capt', 'captain'),
- ('esq', 'esquire'),
- ('ltd', 'limited'),
- ('col', 'colonel'),
- ('ft', 'fort'),
- ]]
+ for x in [
+ ('mrs', 'misess'),
+ ('mr', 'mister'),
+ ('dr', 'doctor'),
+ ('st', 'saint'),
+ ('co', 'company'),
+ ('jr', 'junior'),
+ ('maj', 'major'),
+ ('gen', 'general'),
+ ('drs', 'doctors'),
+ ('rev', 'reverend'),
+ ('lt', 'lieutenant'),
+ ('hon', 'honorable'),
+ ('sgt', 'sergeant'),
+ ('capt', 'captain'),
+ ('esq', 'esquire'),
+ ('ltd', 'limited'),
+ ('col', 'colonel'),
+ ('ft', 'fort'),
+ ]]
# List of (regular expression, replacement) pairs for abbreviations in french:
abbreviations_fr = [(re.compile('\\b%s\\.?' % x[0], re.IGNORECASE), x[1])
- for x in [
- ('M', 'monsieur'),
- ('Mlle', 'mademoiselle'),
- ('Mlles', 'mesdemoiselles'),
- ('Mme', 'Madame'),
- ('Mmes', 'Mesdames'),
- ('N.B', 'nota bene'),
- ('M', 'monsieur'),
- ('p.c.q', 'parce que'),
- ('Pr', 'professeur'),
- ('qqch', 'quelque chose'),
- ('rdv', 'rendez-vous'),
- ('max', 'maximum'),
- ('min', 'minimum'),
- ('no', 'numéro'),
- ('adr', 'adresse'),
- ('dr', 'docteur'),
- ('st', 'saint'),
- ('co', 'companie'),
- ('jr', 'junior'),
- ('sgt', 'sergent'),
- ('capt', 'capitain'),
- ('col', 'colonel'),
- ('av', 'avenue'),
- ('av. J.-C', 'avant Jésus-Christ'),
- ('apr. J.-C', 'après Jésus-Christ'),
- ('art', 'article'),
- ('boul', 'boulevard'),
- ('c.-à-d', 'cest-à-dire'),
- ('etc', 'et cetera'),
- ('ex', 'exemple'),
- ('excl', 'exclusivement'),
- ('boul', 'boulevard'),
- ]]
+ for x in [
+ ('M', 'monsieur'),
+ ('Mlle', 'mademoiselle'),
+ ('Mlles', 'mesdemoiselles'),
+ ('Mme', 'Madame'),
+ ('Mmes', 'Mesdames'),
+ ('N.B', 'nota bene'),
+ ('M', 'monsieur'),
+ ('p.c.q', 'parce que'),
+ ('Pr', 'professeur'),
+ ('qqch', 'quelque chose'),
+ ('rdv', 'rendez-vous'),
+ ('max', 'maximum'),
+ ('min', 'minimum'),
+ ('no', 'numéro'),
+ ('adr', 'adresse'),
+ ('dr', 'docteur'),
+ ('st', 'saint'),
+ ('co', 'companie'),
+ ('jr', 'junior'),
+ ('sgt', 'sergent'),
+ ('capt', 'capitain'),
+ ('col', 'colonel'),
+ ('av', 'avenue'),
+ ('av. J.-C', 'avant Jésus-Christ'),
+ ('apr. J.-C', 'après Jésus-Christ'),
+ ('art', 'article'),
+ ('boul', 'boulevard'),
+ ('c.-à-d', 'cest-à-dire'),
+ ('etc', 'et cetera'),
+ ('ex', 'exemple'),
+ ('excl', 'exclusivement'),
+ ('boul', 'boulevard'),
+ ]]
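
The reflow above is whitespace-only; each entry is still a `(compiled regex, replacement)` pair. A short sketch of how such pairs are applied (the `expand_abbreviations` helper is illustrative, not necessarily the project's exact function):

import re

abbreviations_en = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
                    for x in [('mrs', 'misess'), ('mr', 'mister'), ('dr', 'doctor')]]

def expand_abbreviations(text, abbreviations=abbreviations_en):
    for pattern, replacement in abbreviations:
        text = pattern.sub(replacement, text)
    return text

print(expand_abbreviations("Dr. Smith met Mr. Jones."))
# -> doctor Smith met mister Jones.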

View File

@@ -22,7 +22,7 @@ class AttrDict(dict):
def read_json_with_comments(json_path):
# fallback to json
- with open(json_path, "r", encoding = "utf-8") as f:
+ with open(json_path, "r", encoding="utf-8") as f:
input_str = f.read()
# handle comments
input_str = re.sub(r'\\\n', '', input_str)
@@ -40,7 +40,7 @@ def load_config(config_path: str) -> AttrDict:
ext = os.path.splitext(config_path)[1]
if ext in (".yml", ".yaml"):
- with open(config_path, "r", encoding = "utf-8") as f:
+ with open(config_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
else:
data = read_json_with_comments(config_path)
@@ -61,7 +61,7 @@ def copy_model_files(c, config_file, out_path, new_fields):
"""
# copy config.json
copy_config_path = os.path.join(out_path, 'config.json')
- config_lines = open(config_file, "r", encoding = "utf-8").readlines()
+ config_lines = open(config_file, "r", encoding="utf-8").readlines()
# add extra information fields
for key, value in new_fields.items():
if isinstance(value, str):

View File

@@ -144,8 +144,3 @@ class ModelManager(object):
if isinstance(key, str) and len(my_dict[key]) > 0:
return True
return False

View File

@@ -4,7 +4,7 @@ from torch import nn
from torch.nn import functional as F
- class TorchSTFT(nn.Module):
+ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
""" Torch based STFT operation """
super(TorchSTFT, self).__init__()

View File

@@ -22,8 +22,10 @@ class PositionalEncoding(nn.Module):
def forward(self, x, noise_level):
if x.shape[2] > self.pe.shape[1]:
- self.init_pe_matrix(x.shape[1] ,x.shape[2], x)
- return x + noise_level[..., None, None] + self.pe[:, :x.size(2)].repeat(x.shape[0], 1, 1) / self.C
+ self.init_pe_matrix(x.shape[1], x.shape[2], x)
+ return x + noise_level[..., None,
+ None] + self.pe[:, :x.size(2)].repeat(
+ x.shape[0], 1, 1) / self.C
def init_pe_matrix(self, n_channels, max_len, x):
pe = torch.zeros(max_len, n_channels)
@@ -171,5 +173,4 @@ class DBlock(nn.Module):
self.res_block = weight_norm(self.res_block)
for idx, layer in enumerate(self.main_block):
if len(layer.state_dict()) != 0:
- self.main_block[idx] = weight_norm(layer)
+ self.main_block[idx] = weight_norm(layer)

View File

@@ -79,7 +79,7 @@ class Wavegrad(nn.Module):
return x
def load_noise_schedule(self, path):
- beta = np.load(path, allow_pickle=True).item()['beta']
+ beta = np.load(path, allow_pickle=True).item()['beta'] # pylint: disable=unexpected-keyword-arg
self.compute_noise_level(beta)
@torch.no_grad()
@@ -91,8 +91,8 @@ class Wavegrad(nn.Module):
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
sqrt_alpha_hat = self.noise_level.to(x)
for n in range(len(self.alpha) - 1, -1, -1):
- y_n = self.c1[n] * (y_n -
- self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
+ y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(
+ y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
if n > 0:
z = torch.randn_like(y_n)
y_n += self.sigma[n - 1] * z

View File

@@ -73,15 +73,15 @@ class Stretch2d(nn.Module):
class UpsampleNetwork(nn.Module):
def __init__(
- self,
- feat_dims,
- upsample_scales,
- compute_dims,
- num_res_blocks,
- res_out_dims,
- pad,
- use_aux_net,
- ):
+ self,
+ feat_dims,
+ upsample_scales,
+ compute_dims,
+ num_res_blocks,
+ res_out_dims,
+ pad,
+ use_aux_net,
+ ):
super().__init__()
self.total_scale = np.cumproduct(upsample_scales)[-1]
self.indent = pad * self.total_scale
@@ -118,9 +118,8 @@ class UpsampleNetwork(nn.Module):
class Upsample(nn.Module):
- def __init__(
- self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net
- ):
+ def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims,
+ res_out_dims, use_aux_net):
super().__init__()
self.scale = scale
self.pad = pad

View File

@@ -44,9 +44,11 @@ def log_sum_exp(x):
# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
- def discretized_mix_logistic_loss(
- y_hat, y, num_classes=65536, log_scale_min=None, reduce=True
- ):
+ def discretized_mix_logistic_loss(y_hat,
+ y,
+ num_classes=65536,
+ log_scale_min=None,
+ reduce=True):
if log_scale_min is None:
log_scale_min = float(np.log(1e-14))
y_hat = y_hat.permute(0, 2, 1)

View File

@@ -7,7 +7,7 @@ import pickle as pickle_tts
from TTS.utils.io import RenamingUnpickler
- def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):
+ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
try:
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
except ModuleNotFoundError:

View File

@@ -166,7 +166,7 @@ class SSIMLossTests(unittest.TestCase):
dummy_target = T.zeros(4, 8, 128).float()
dummy_length = (T.ones(4) * 8).long()
output = layer(dummy_input, dummy_target, dummy_length)
- assert abs(output.item() - 1.0) < 1e-4 , "1.0 vs {}".format(output.item())
+ assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item())
# test if padded values of input makes any difference
dummy_input = T.ones(4, 8, 128).float()
@@ -217,4 +217,3 @@ class SSIMLossTests(unittest.TestCase):
(sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
assert output.item() == 0, "0 vs {}".format(output.item())

View File

@@ -161,7 +161,7 @@ def test_speedy_speech():
x_lengths,
y_lengths,
durations,
- g=torch.rand((B,256)).to(device))
+ g=torch.rand((B, 256)).to(device))
assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}"
assert list(attn.shape) == [B, T_de, T_en]

View File

@@ -356,4 +356,3 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref)
count += 1

View File

@@ -17,5 +17,5 @@ def test_currency() -> None:
def test_expand_numbers() -> None:
- assert "minus one" == phoneme_cleaners("-1")
- assert "one" == phoneme_cleaners("1")
+ assert phoneme_cleaners("-1") == 'minus one'
+ assert phoneme_cleaners("1") == 'one'

View File

@@ -17,7 +17,7 @@ def test_phoneme_to_sequence():
lang = "en-us"
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
- sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+ _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
gt = 'ɹiːsənt ɹᵻsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪŋkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹᵻspɑːnsᵻbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjʊleɪʃən ænd lɜːnɪŋ!'
assert text_hat == text_hat_with_params == gt

View File

@@ -20,18 +20,18 @@ class WavegradTrainTest(unittest.TestCase):
criterion = torch.nn.L1Loss().to(device)
model = Wavegrad(in_channels=80,
- out_channels=1,
- upsample_factors=[5, 5, 3, 2, 2],
- upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2],
- [1, 2, 4, 8], [1, 2, 4, 8],
- [1, 2, 4, 8]])
+ out_channels=1,
+ upsample_factors=[5, 5, 3, 2, 2],
+ upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2],
+ [1, 2, 4, 8], [1, 2, 4, 8],
+ [1, 2, 4, 8]])
model_ref = Wavegrad(in_channels=80,
- out_channels=1,
- upsample_factors=[5, 5, 3, 2, 2],
- upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2],
- [1, 2, 4, 8], [1, 2, 4, 8],
- [1, 2, 4, 8]])
+ out_channels=1,
+ upsample_factors=[5, 5, 3, 2, 2],
+ upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2],
+ [1, 2, 4, 8], [1, 2, 4, 8],
+ [1, 2, 4, 8]])
model.train()
model.to(device)
betas = np.linspace(1e-6, 1e-2, 1000)