Update TTS.tts formatters (#1228)

* Return Dict from tts formatters

* Make style
pull/1251/head
Eren Gölge 2022-02-11 23:03:43 +01:00 committed by GitHub
parent 5e3f499a69
commit 127118c637
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 153 additions and 141 deletions

View File

@ -29,7 +29,9 @@ parser.add_argument(
help="Path to dataset config file.",
)
parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
parser.add_argument("--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None)
parser.add_argument(
"--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None
)
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
@ -41,7 +43,10 @@ meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_spli
wav_files = meta_data_train + meta_data_eval
speaker_manager = SpeakerManager(
encoder_model_path=args.model_path, encoder_config_path=args.config_path, d_vectors_file_path=args.old_file, use_cuda=args.use_cuda
encoder_model_path=args.model_path,
encoder_config_path=args.config_path,
d_vectors_file_path=args.old_file,
use_cuda=args.use_cuda,
)
# compute speaker embeddings

View File

@ -51,7 +51,7 @@ def main():
N = 0
for item in tqdm(dataset_items):
# compute features
wav = ap.load_wav(item if isinstance(item, str) else item[1])
wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"])
linear = ap.spectrogram(wav)
mel = ap.melspectrogram(wav)
@ -59,13 +59,13 @@ def main():
N += mel.shape[1]
mel_sum += mel.sum(1)
linear_sum += linear.sum(1)
mel_square_sum += (mel ** 2).sum(axis=1)
linear_square_sum += (linear ** 2).sum(axis=1)
mel_square_sum += (mel**2).sum(axis=1)
linear_square_sum += (linear**2).sum(axis=1)
mel_mean = mel_sum / N
mel_scale = np.sqrt(mel_square_sum / N - mel_mean ** 2)
mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2)
linear_mean = linear_sum / N
linear_scale = np.sqrt(linear_square_sum / N - linear_mean ** 2)
linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2)
output_file_path = args.out_path
stats = {}

View File

@ -24,6 +24,7 @@ def main():
# load all datasets
train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
items = train_items + eval_items
texts = "".join(item[0] for item in items)

View File

@ -43,6 +43,11 @@ def main():
items = train_items + eval_items
print("Num items:", len(items))
is_lang_def = all(item["language"] for item in items)
if not c.phoneme_language or not is_lang_def:
raise ValueError("Phoneme language must be defined in config.")
phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
phones = []
for ph in phonemes:

View File

@ -1,4 +1,5 @@
import os
import torch
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config

View File

@ -78,12 +78,12 @@ class SpeakerEncoderDataset(Dataset):
mel = self.ap.melspectrogram(wav).astype("float32")
# sample seq_len
assert text.size > 0, self.items[idx][1]
assert wav.size > 0, self.items[idx][1]
assert text.size > 0, self.items[idx]["audio_file"]
assert wav.size > 0, self.items[idx]["audio_file"]
sample = {
"mel": mel,
"item_idx": self.items[idx][1],
"item_idx": self.items[idx]["audio_file"],
"speaker_name": speaker_name,
}
return sample
@ -91,8 +91,8 @@ class SpeakerEncoderDataset(Dataset):
def __parse_items(self):
self.speaker_to_utters = {}
for i in self.items:
path_ = i[1]
speaker_ = i[2]
path_ = i["audio_file"]
speaker_ = i["speaker_name"]
if speaker_ in self.speaker_to_utters.keys():
self.speaker_to_utters[speaker_].append(path_)
else:

View File

@ -229,7 +229,7 @@ class ResNetSpeakerEncoder(nn.Module):
x = torch.sum(x * w, dim=2)
elif self.encoder_type == "ASP":
mu = torch.sum(x * w, dim=2)
sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
x = torch.cat((mu, sg), 1)
x = x.view(x.size()[0], -1)

View File

@ -113,7 +113,7 @@ class AugmentWAV(object):
def additive_noise(self, noise_type, audio):
clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4)
clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
noise_list = random.sample(
self.noise_list[noise_type],
@ -135,7 +135,7 @@ class AugmentWAV(object):
self.additive_noise_config[noise_type]["min_snr_in_db"],
self.additive_noise_config[noise_type]["max_num_noises"],
)
noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4)
noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4)
noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio
if noises_wav is None:
@ -154,7 +154,7 @@ class AugmentWAV(object):
rir_file = random.choice(self.rir_files)
rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
rir = rir / np.sqrt(np.sum(rir ** 2))
rir = rir / np.sqrt(np.sum(rir**2))
return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]
def apply_one(self, audio):

View File

@ -75,14 +75,14 @@ def load_tts_samples(
formatter = _get_formatter_by_name(name)
# load train set
meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
meta_data_train = [[*item, language] for item in meta_data_train]
meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
# load evaluation split if set
if eval_split:
if meta_file_val:
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
meta_data_eval = [[*item, language] for item in meta_data_eval]
meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval]
else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
meta_data_eval_all += meta_data_eval
@ -91,12 +91,12 @@ def load_tts_samples(
if dataset.meta_file_attn_mask:
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
for idx, ins in enumerate(meta_data_train_all):
attn_file = meta_data[ins[1]].strip()
meta_data_train_all[idx].append(attn_file)
attn_file = meta_data[ins["audio_file"]].strip()
meta_data_train_all[idx].update({"alignment_file": attn_file})
if meta_data_eval_all:
for idx, ins in enumerate(meta_data_eval_all):
attn_file = meta_data[ins[1]].strip()
meta_data_eval_all[idx].append(attn_file)
attn_file = meta_data[ins["audio_file"]].strip()
meta_data_eval_all[idx].update({"alignment_file": attn_file})
# set none for the next iter
formatter = None
return meta_data_train_all, meta_data_eval_all

View File

@ -21,7 +21,7 @@ class TTSDataset(Dataset):
text_cleaner: list,
compute_linear_spec: bool,
ap: AudioProcessor,
meta_data: List[List],
meta_data: List[Dict],
compute_f0: bool = False,
f0_cache_path: str = None,
characters: Dict = None,
@ -54,7 +54,7 @@ class TTSDataset(Dataset):
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
meta_data (list): List of dataset instances.
meta_data (list): List of dataset samples.
compute_f0 (bool): compute f0 if True. Defaults to False.
@ -199,15 +199,9 @@ class TTSDataset(Dataset):
def load_data(self, idx):
item = self.items[idx]
raw_text = item["text"]
if len(item) == 5:
text, wav_file, speaker_name, language_name, attn_file = item
else:
text, wav_file, speaker_name, language_name = item
attn = None
raw_text = text
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)
# apply noise for augmentation
if self.use_noise_augment:
@ -216,12 +210,12 @@ class TTSDataset(Dataset):
if not self.input_seq_computed:
if self.use_phonemes:
text = self._load_or_generate_phoneme_sequence(
wav_file,
text,
item["audio_file"],
item["text"],
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
language_name if language_name else self.phoneme_language,
item["language"] if item["language"] else self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
@ -229,7 +223,7 @@ class TTSDataset(Dataset):
else:
text = np.asarray(
text_to_sequence(
text,
item["text"],
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
@ -238,11 +232,12 @@ class TTSDataset(Dataset):
dtype=np.int32,
)
assert text.size > 0, self.items[idx][1]
assert wav.size > 0, self.items[idx][1]
assert text.size > 0, self.items[idx]["audio_file"]
assert wav.size > 0, self.items[idx]["audio_file"]
if "attn_file" in locals():
attn = np.load(attn_file)
attn = None
if "alignment_file" in item:
attn = np.load(item["alignment_file"])
if len(text) > self.max_seq_len:
# return a different sample if the phonemized
@ -252,7 +247,7 @@ class TTSDataset(Dataset):
pitch = None
if self.compute_f0:
pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, item["audio_file"], self.f0_cache_path)
pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
sample = {
@ -261,10 +256,10 @@ class TTSDataset(Dataset):
"wav": wav,
"pitch": pitch,
"attn": attn,
"item_idx": self.items[idx][1],
"speaker_name": speaker_name,
"language_name": language_name,
"wav_file_name": os.path.basename(wav_file),
"item_idx": item["audio_file"],
"speaker_name": item["speaker_name"],
"language_name": item["language"],
"wav_file_name": os.path.basename(item["audio_file"]),
}
return sample
@ -272,11 +267,10 @@ class TTSDataset(Dataset):
def _phoneme_worker(args):
item = args[0]
func_args = args[1]
text, wav_file, *_ = item
func_args[3] = (
item[3] if item[3] else func_args[3]
item["language"] if "language" in item and item["language"] else func_args[3]
) # override phoneme language if specified by the dataset formatter
phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
phonemes = TTSDataset._load_or_generate_phoneme_sequence(item["audio_file"], item["text"], *func_args)
return phonemes
def compute_input_seq(self, num_workers=0):
@ -286,10 +280,9 @@ class TTSDataset(Dataset):
if self.verbose:
print(" | > Computing input sequences ...")
for idx, item in enumerate(tqdm.tqdm(self.items)):
text, *_ = item
sequence = np.asarray(
text_to_sequence(
text,
item["text"],
[self.cleaners],
custom_symbols=self.custom_symbols,
tp=self.characters,
@ -337,10 +330,10 @@ class TTSDataset(Dataset):
if by_audio_len:
lengths = []
for item in self.items:
lengths.append(os.path.getsize(item[1]) / 16 * 8) # assuming 16bit audio
lengths.append(os.path.getsize(item["audio_file"]) / 16 * 8) # assuming 16bit audio
lengths = np.array(lengths)
else:
lengths = np.array([len(ins[0]) for ins in self.items])
lengths = np.array([len(ins["text"]) for ins in self.items])
idxs = np.argsort(lengths)
new_items = []
@ -555,7 +548,7 @@ class PitchExtractor:
def __init__(
self,
items: List[List],
items: List[Dict],
verbose=False,
):
self.items = items
@ -614,10 +607,9 @@ class PitchExtractor:
item = args[0]
ap = args[1]
cache_path = args[2]
_, wav_file, *_ = item
pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
pitch_file = PitchExtractor.create_pitch_file_path(item["audio_file"], cache_path)
if not os.path.exists(pitch_file):
pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
pitch = PitchExtractor._compute_and_save_pitch(ap, item["audio_file"], pitch_file)
return pitch
return None

View File

@ -24,7 +24,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("\t")
wav_file = os.path.join(root_path, cols[0] + ".wav")
text = cols[1]
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -39,7 +39,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
wav_file = cols[1].strip()
text = cols[0].strip()
wav_file = os.path.join(root_path, "wavs", wav_file)
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -55,7 +55,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume
text = cols[1].strip()
folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
wav_file = os.path.join(root_path, folder_name, wav_file)
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -101,7 +101,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
if os.path.isfile(wav_file):
text = cols[1].strip()
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
else:
# M-AI-Labs have some missing samples, so just print the warning
print("> File %s does not exist!" % (wav_file))
@ -119,7 +119,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -133,7 +133,7 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append([text, wav_file, f"ljspeech-{idx}"])
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{idx}"})
return items
@ -150,7 +150,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
if not os.path.exists(wav_file):
print(f" [!] {wav_file} in metafile does not exist. Skipping...")
continue
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -165,7 +165,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
text = cols[1]
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -179,7 +179,7 @@ def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
text = cols[1]
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -193,7 +193,7 @@ def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
utt_id = line.split()[1]
text = line[line.find('"') + 1 : line.rfind('"') - 1]
wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items
@ -213,7 +213,7 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
if speaker_name in ignored_speakers:
continue
wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
items.append([text, wav_file, "MCV_" + speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
return items
@ -240,7 +240,7 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_name in ignored_speakers:
continue
items.append([text, wav_file, "LTTS_" + speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
for item in items:
assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
return items
@ -259,7 +259,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
skipped_files.append(wav_file)
continue
text = cols[1].strip()
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
return items
@ -281,7 +281,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
items.append([text, wav_file, speaker_id])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
return items
@ -299,7 +299,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append([text, wav_file, "VCTK_" + speaker_id])
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
return items
@ -334,7 +334,7 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker in ignored_speakers:
continue
items.append([text, wav_file, "MLS_" + speaker])
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
return items
@ -404,7 +404,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin
for line in ttf:
wav_name, text = line.rstrip("\n").split("|")
wav_path = os.path.join(root_path, "clips_22", wav_name)
items.append([text, wav_path, speaker_name])
items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
return items
@ -418,5 +418,5 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2].replace(" ", "")
items.append([text, wav_file, speaker_name])
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
return items

View File

@ -113,7 +113,7 @@ class ActNorm(nn.Module):
denom = torch.sum(x_mask, [0, 2])
m = torch.sum(x * x_mask, [0, 2]) / denom
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
v = m_sq - (m ** 2)
v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)

View File

@ -65,7 +65,7 @@ class WN(torch.nn.Module):
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
# intermediate layers
for i in range(num_layers):
dilation = dilation_rate ** i
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding

View File

@ -101,7 +101,7 @@ class Encoder(nn.Module):
self.encoder_type = encoder_type
# embedding layer
self.emb = nn.Embedding(num_chars, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
# init encoder module
if encoder_type.lower() == "rel_pos_transformer":
if use_prenet:

View File

@ -88,7 +88,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
# relative positional encoding layers
if rel_attn_window_size is not None:
n_heads_rel = 1 if heads_share else num_heads
rel_stddev = self.k_channels ** -0.5
rel_stddev = self.k_channels**-0.5
emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev
)
@ -235,7 +235,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0])
x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]

View File

@ -218,7 +218,7 @@ class GuidedAttentionLoss(torch.nn.Module):
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)))
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))
@staticmethod
def _make_masks(ilens, olens):
@ -665,7 +665,7 @@ class VitsDiscriminatorLoss(nn.Module):
dr = dr.float()
dg = dg.float()
real_loss = torch.mean((1 - dr) ** 2)
fake_loss = torch.mean(dg ** 2)
fake_loss = torch.mean(dg**2)
loss += real_loss + fake_loss
real_losses.append(real_loss.item())
fake_losses.append(fake_loss.item())

View File

@ -141,7 +141,7 @@ class MultiHeadAttention(nn.Module):
# score = softmax(QK^T / (d_k ** 0.5))
scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k]
scores = scores / (self.key_dim ** 0.5)
scores = scores / (self.key_dim**0.5)
scores = F.softmax(scores, dim=3)
# out = score * V

View File

@ -57,7 +57,7 @@ class TextEncoder(nn.Module):
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
if language_emb_dim:
hidden_channels += language_emb_dim

View File

@ -33,7 +33,7 @@ class DilatedDepthSeparableConv(nn.Module):
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(num_layers):
dilation = kernel_size ** i
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
@ -264,7 +264,7 @@ class StochasticDurationPredictor(nn.Module):
# posterior encoder - neg log likelihood
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
nll_posterior_encoder = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q
torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2]) - logdet_tot_q
)
z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask
@ -279,7 +279,7 @@ class StochasticDurationPredictor(nn.Module):
z = torch.flip(z, [1])
# flow layers - neg log likelihood
nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll_flow_layers + nll_posterior_encoder
flows = list(reversed(self.flows))

View File

@ -206,9 +206,9 @@ class GlowTTS(BaseTTS):
with torch.no_grad():
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
@ -255,9 +255,9 @@ class GlowTTS(BaseTTS):
# find the alignment path between z and encoder output
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

View File

@ -4,7 +4,6 @@ from itertools import chain
from typing import Dict, List, Tuple
import torch
import torchaudio
from coqpit import Coqpit
from torch import nn
@ -424,9 +423,9 @@ class Vits(BaseTTS):
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
):
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.audio_config["sample_rate"],
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
)
orig_freq=self.audio_config["sample_rate"],
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
)
else:
self.audio_transform = None
@ -591,9 +590,9 @@ class Vits(BaseTTS):
with torch.no_grad():
o_scale = torch.exp(-2 * logs_p)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3 + logp1 + logp4
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
@ -692,10 +691,17 @@ class Vits(BaseTTS):
if self.args.use_sdp:
logw = self.duration_predictor(
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
x,
x_mask,
g=g if self.args.condition_dp_on_speaker else None,
reverse=True,
noise_scale=self.inference_noise_scale_dp,
lang_emb=lang_emb,
)
else:
logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb)
logw = self.duration_predictor(
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
)
w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w)

View File

@ -113,7 +113,7 @@ def _set_file_path(path):
def get_language_weighted_sampler(items: list):
language_names = np.array([item[3] for item in items])
language_names = np.array([item["language"] for item in items])
unique_language_names = np.unique(language_names).tolist()
language_ids = [unique_language_names.index(l) for l in language_names]
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])

View File

@ -118,7 +118,7 @@ class SpeakerManager:
Returns:
Tuple[Dict, int]: speaker IDs and number of speakers.
"""
speakers = sorted({item[2] for item in items})
speakers = sorted({item["speaker_name"] for item in items})
speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(speaker_ids)
return speaker_ids, num_speakers
@ -414,7 +414,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
def get_speaker_weighted_sampler(items: list):
speaker_names = np.array([item[2] for item in items])
speaker_names = np.array([item["speaker_name"] for item in items])
unique_speaker_names = np.unique(speaker_names).tolist()
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])

View File

@ -8,7 +8,7 @@ from torch.autograd import Variable
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)])
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
return gauss / gauss.sum()
@ -33,8 +33,8 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
C1 = 0.01 ** 2
C2 = 0.03 ** 2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

View File

@ -142,10 +142,10 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
if self.power is not None:
S = S ** self.power
S = S**self.power
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
@ -634,8 +634,8 @@ class AudioProcessor(object):
S = self._db_to_amp(S)
# Reconstruct phase
if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
return self._griffin_lim(S ** self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
return self._griffin_lim(S**self.power)
def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
"""Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
@ -643,8 +643,8 @@ class AudioProcessor(object):
S = self._db_to_amp(D)
S = self._mel_to_linear(S) # Convert back to linear
if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
return self._griffin_lim(S ** self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
return self._griffin_lim(S**self.power)
def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
"""Convert a full scale linear spectrogram output of a network to a melspectrogram.
@ -781,7 +781,7 @@ class AudioProcessor(object):
@staticmethod
def _rms_norm(wav, db_level=-27):
r = 10 ** (db_level / 20)
a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))
a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
return wav * a
def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
@ -853,7 +853,7 @@ class AudioProcessor(object):
@staticmethod
def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
mu = 2 ** qc - 1
mu = 2**qc - 1
# wav_abs = np.minimum(np.abs(wav), 1.0)
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
# Quantize signal to the specified number of levels.
@ -865,13 +865,13 @@ class AudioProcessor(object):
@staticmethod
def mulaw_decode(wav, qc):
"""Recovers waveform from quantized values."""
mu = 2 ** qc - 1
mu = 2**qc - 1
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x
@staticmethod
def encode_16bits(x):
return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
@staticmethod
def quantize(x: np.ndarray, bits: int) -> np.ndarray:
@ -884,12 +884,12 @@ class AudioProcessor(object):
Returns:
np.ndarray: Quantized waveform.
"""
return (x + 1.0) * (2 ** bits - 1) / 2
return (x + 1.0) * (2**bits - 1) / 2
@staticmethod
def dequantize(x, bits):
"""Dequantize a waveform from the given number of bits."""
return 2 * x / (2 ** bits - 1) - 1
return 2 * x / (2**bits - 1) - 1
def _log(x, base):

View File

@ -128,7 +128,7 @@ def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") ->
while True:
# Read by chunk to avoid filling memory
chunk = file_obj.read(1024 ** 2)
chunk = file_obj.read(1024**2)
if not chunk:
break
hash_func.update(chunk)

View File

@ -39,7 +39,7 @@ class NoamLR(torch.optim.lr_scheduler._LRScheduler):
def get_lr(self):
step = max(self.last_epoch, 1)
return [
base_lr * self.warmup_steps ** 0.5 * min(step * self.warmup_steps ** -1.5, step ** -0.5)
base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
for base_lr in self.base_lrs
]
@ -63,7 +63,7 @@ def lr_decay(init_lr, global_step, warmup_steps):
It is only being used by the Speaker Encoder trainer."""
warmup_steps = float(warmup_steps)
step = global_step + 1.0
lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5)
lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5, step**-0.5)
return lr

View File

@ -127,5 +127,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
lr_scheduler_gen: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1})
lr_scheduler_disc: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1})
lr_scheduler_disc_params: dict = field(
default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}
)
scheduler_after_epoch: bool = False

View File

@ -111,7 +111,7 @@ class WaveRNNDataset(Dataset):
elif isinstance(self.mode, int):
coarse = np.stack(coarse).astype(np.int64)
coarse = torch.LongTensor(coarse)
x_input = 2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0
x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0
y_coarse = coarse[:, 1:]
mels = torch.FloatTensor(mels)
return x_input, mels, y_coarse

View File

@ -126,9 +126,9 @@ class LVCBlock(torch.nn.Module):
)
for i in range(conv_layers):
padding = (3 ** i) * int((conv_kernel_size - 1) / 2)
padding = (3**i) * int((conv_kernel_size - 1) / 2)
conv = torch.nn.Conv1d(
in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i
in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3**i
)
self.convs.append(conv)

View File

@ -12,7 +12,7 @@ class ResidualStack(nn.Module):
self.blocks = nn.ModuleList()
for idx in range(num_res_blocks):
layer_kernel_size = kernel_size
layer_dilation = layer_kernel_size ** idx
layer_dilation = layer_kernel_size**idx
layer_padding = base_padding * layer_dilation
self.blocks += [
nn.Sequential(

View File

@ -72,6 +72,6 @@ class ResidualBlock(torch.nn.Module):
s = self.conv1x1_skip(x)
# for residual connection
x = (self.conv1x1_out(x) + residual) * (0.5 ** 2)
x = (self.conv1x1_out(x) + residual) * (0.5**2)
return x, s

View File

@ -207,7 +207,7 @@ class HifiganGenerator(torch.nn.Module):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2 ** i),
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,

View File

@ -36,7 +36,7 @@ class MelganGenerator(nn.Module):
# upsampling layers and residual stacks
for idx, upsample_factor in enumerate(upsample_factors):
layer_in_channels = base_channels // (2 ** idx)
layer_in_channels = base_channels // (2**idx)
layer_out_channels = base_channels // (2 ** (idx + 1))
layer_filter_size = upsample_factor * 2
layer_stride = upsample_factor

View File

@ -35,7 +35,7 @@ class ParallelWaveganDiscriminator(nn.Module):
if i == 0:
dilation = 1
else:
dilation = i if dilation_factor == 1 else dilation_factor ** i
dilation = i if dilation_factor == 1 else dilation_factor**i
conv_in_channels = conv_channels
padding = (kernel_size - 1) // 2 * dilation
conv_layer = [

View File

@ -142,7 +142,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2 ** x):
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]

View File

@ -130,7 +130,7 @@ class UnivnetGenerator(torch.nn.Module):
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2 ** x):
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]

View File

@ -153,7 +153,7 @@ class Wavegrad(BaseVocoder):
noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
noise_scale = noise_scale.unsqueeze(1)
noise = torch.randn_like(y_0)
noisy_audio = noise_scale * y_0 + (1.0 - noise_scale ** 2) ** 0.5 * noise
noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise
return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
def compute_noise_level(self, beta):
@ -161,8 +161,8 @@ class Wavegrad(BaseVocoder):
self.num_steps = len(beta)
alpha = 1 - beta
alpha_hat = np.cumprod(alpha)
noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
noise_level = alpha_hat ** 0.5
noise_level = np.concatenate([[1.0], alpha_hat**0.5], axis=0)
noise_level = alpha_hat**0.5
# pylint: disable=not-callable
self.beta = torch.tensor(beta.astype(np.float32))
@ -170,7 +170,7 @@ class Wavegrad(BaseVocoder):
self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
self.noise_level = torch.tensor(noise_level.astype(np.float32))
self.c1 = 1 / self.alpha ** 0.5
self.c1 = 1 / self.alpha**0.5
self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5
self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5

View File

@ -225,7 +225,7 @@ class Wavernn(BaseVocoder):
super().__init__(config)
if isinstance(self.args.mode, int):
self.n_classes = 2 ** self.args.mode
self.n_classes = 2**self.args.mode
elif self.args.mode == "mold":
self.n_classes = 3 * 10
elif self.args.mode == "gauss":

View File

@ -5,13 +5,13 @@ from tests import get_tests_input_path
from TTS.tts.datasets.formatters import common_voice
class TestPreprocessors(unittest.TestCase):
class TestTTSFormatters(unittest.TestCase):
def test_common_voice_preprocessor(self): # pylint: disable=no-self-use
root_path = get_tests_input_path()
meta_file = "common_voice.tsv"
items = common_voice(root_path, meta_file)
assert items[0][0] == "The applicants are invited for coffee and visa is given immediately."
assert items[0][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav")
assert items[0]["text"] == "The applicants are invited for coffee and visa is given immediately."
assert items[0]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav")
assert items[-1][0] == "Competition for limited resources has also resulted in some local conflicts."
assert items[-1][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")
assert items[-1]["text"] == "Competition for limited resources has also resulted in some local conflicts."
assert items[-1]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")

View File

@ -46,6 +46,6 @@ def test_wavernn():
config.model_args.mode = 4
model = Wavernn(config)
output = model(dummy_x, dummy_m)
assert np.all(output.shape == (2, 1280, 2 ** 4)), output.shape
assert np.all(output.shape == (2, 1280, 2**4)), output.shape
output = model.inference(dummy_y, True, 5500, 550)
assert np.all(output.shape == (256 * (y_size - 1),))