make style

pull/1032/head
WeberJulian 2021-11-02 17:31:14 +01:00 committed by Eren Gölge
parent e1bdeacd2e
commit 3e9ca4b95d
15 changed files with 158 additions and 87 deletions

View File

@ -1,14 +1,15 @@
"""Find all the unique characters in a dataset"""
import argparse
import multiprocessing
from argparse import RawTextHelpFormatter
import numpy
from tqdm.contrib.concurrent import process_map
from TTS.config import load_config
from TTS.tts.datasets import load_meta_data
import numpy
import multiprocessing
from TTS.tts.utils.text import text2phone
from tqdm.contrib.concurrent import process_map
def compute_phonemes(item):
try:
@ -19,6 +20,7 @@ def compute_phonemes(item):
return []
return list(set(ph))
def main():
global c
# pylint: disable=bad-option-value
@ -51,8 +53,6 @@ def main():
phones_force_lower = [c.lower() for c in phones]
phones_force_lower = set(phones_force_lower)
print(f" > Number of unique phonemes: {len(phones)}")
print(f" > Unique phonemes: {''.join(sorted(phones))}")
print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")

View File

@ -1,26 +1,27 @@
# This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
import os
import tqdm
import glob
import argparse
import pathlib
import collections
import contextlib
import glob
import multiprocessing
import os
import pathlib
import sys
import wave
from itertools import chain
import numpy as np
import tqdm
import webrtcvad
from tqdm.contrib.concurrent import process_map
import multiprocessing
from itertools import chain
def read_wave(path):
"""Reads a .wav file.
Takes the path, and returns (PCM audio data, sample rate).
"""
with contextlib.closing(wave.open(path, 'rb')) as wf:
with contextlib.closing(wave.open(path, "rb")) as wf:
num_channels = wf.getnchannels()
assert num_channels == 1
sample_width = wf.getsampwidth()
@ -36,7 +37,7 @@ def write_wave(path, audio, sample_rate):
Takes path, PCM audio data, and sample rate.
"""
with contextlib.closing(wave.open(path, 'wb')) as wf:
with contextlib.closing(wave.open(path, "wb")) as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(sample_rate)
@ -45,6 +46,7 @@ def write_wave(path, audio, sample_rate):
class Frame(object):
"""Represents a "frame" of audio data."""
def __init__(self, bytes, timestamp, duration):
self.bytes = bytes
self.timestamp = timestamp
@ -64,13 +66,12 @@ def frame_generator(frame_duration_ms, audio, sample_rate):
timestamp = 0.0
duration = (float(n) / sample_rate) / 2.0
while offset + n < len(audio):
yield Frame(audio[offset:offset + n], timestamp, duration)
yield Frame(audio[offset : offset + n], timestamp, duration)
timestamp += duration
offset += n
def vad_collector(sample_rate, frame_duration_ms,
padding_duration_ms, vad, frames):
def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
"""Filters out non-voiced audio frames.
Given a webrtcvad.Vad and a source of audio frames, yields only
@ -133,25 +134,26 @@ def vad_collector(sample_rate, frame_duration_ms,
# unvoiced, then enter NOTTRIGGERED and yield whatever
# audio we've collected.
if num_unvoiced > 0.9 * ring_buffer.maxlen:
#sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
# sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
triggered = False
yield b''.join([f.bytes for f in voiced_frames])
yield b"".join([f.bytes for f in voiced_frames])
ring_buffer.clear()
voiced_frames = []
# If we have any leftover voiced audio when we run out of input,
# yield it.
if voiced_frames:
yield b''.join([f.bytes for f in voiced_frames])
yield b"".join([f.bytes for f in voiced_frames])
def remove_silence(filepath):
filename = os.path.basename(filepath)
output_path = filepath.replace(os.path.join(args.input_dir, ''),os.path.join(args.output_dir, ''))
output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
# ignore if the file exists
if os.path.exists(output_path) and not args.force:
return False
# create all directory structure
pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
padding_duration_ms = 300 # default 300
padding_duration_ms = 300 # default 300
audio, sample_rate = read_wave(filepath)
vad = webrtcvad.Vad(int(args.aggressiveness))
frames = frame_generator(30, audio, sample_rate)
@ -180,6 +182,7 @@ def remove_silence(filepath):
# if fail to remove silence just write the file
write_wave(output_path, audio, sample_rate)
def preprocess_audios():
files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True))
print("> Number of files: ", len(files))
@ -193,21 +196,31 @@ def preprocess_audios():
else:
print("> No files Found !")
if __name__ == "__main__":
"""
usage
python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2
"""
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_dir', type=str, default='../VCTK-Corpus',
help='Dataset root dir')
parser.add_argument('-o', '--output_dir', type=str, default='../VCTK-Corpus-removed-silence',
help='Output Dataset dir')
parser.add_argument('-f', '--force', type=bool, default=True,
help='Force the replace of exists files')
parser.add_argument('-g', '--glob', type=str, default='**/*.wav',
help='path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav')
parser.add_argument('-a', '--aggressiveness', type=int, default=2,
help='set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.')
parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir")
parser.add_argument(
"-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
)
parser.add_argument("-f", "--force", type=bool, default=True, help="Force the replace of exists files")
parser.add_argument(
"-g",
"--glob",
type=str,
default="**/*.wav",
help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav",
)
parser.add_argument(
"-a",
"--aggressiveness",
type=int,
default=2,
help="set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.",
)
args = parser.parse_args()
preprocess_audios()
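Editor's note: as background for the VAD pipeline this file reformats, here is a minimal sketch of the underlying `webrtcvad` API: 16-bit mono PCM is cut into 30 ms frames and each frame is classified as speech or silence. The one second of synthetic silence is only there to keep the example self-contained.

```python
# Minimal VAD sketch; real usage reads PCM from a .wav as in read_wave() above.
import struct

import webrtcvad

sample_rate = 16000
frame_ms = 30
n_samples = sample_rate * frame_ms // 1000   # samples per 30 ms frame
frame_bytes = n_samples * 2                  # 16-bit samples -> 2 bytes each

# one second of silence as little-endian 16-bit PCM
audio = struct.pack("<" + "h" * sample_rate, *([0] * sample_rate))

vad = webrtcvad.Vad(2)  # aggressiveness 0-3, mirroring the -a/--aggressiveness flag
voiced = 0
for offset in range(0, len(audio) - frame_bytes + 1, frame_bytes):
    frame = audio[offset : offset + frame_bytes]
    if vad.is_speech(frame, sample_rate):
        voiced += 1
print(f"voiced frames: {voiced}")  # 0 for pure silence
```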

View File

@ -5,20 +5,20 @@ import torch.nn as nn
from TTS.utils.io import load_fsspec
class PreEmphasis(torch.nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
self.register_buffer(
'filter', torch.FloatTensor([-self.coefficient, 1.]).unsqueeze(0).unsqueeze(0)
)
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
def forward(self, x):
assert len(x.size()) == 2
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), 'reflect')
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
@ -110,8 +110,15 @@ class ResNetSpeakerEncoder(nn.Module):
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
torchaudio.transforms.MelSpectrogram(sample_rate=audio_config["sample_rate"], n_fft=audio_config["fft_size"], win_length=audio_config["win_length"], hop_length=audio_config["hop_length"], window_fn=torch.hamming_window, n_mels=audio_config["num_mels"])
)
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
else:
self.torch_spec = None
@ -213,7 +220,7 @@ class ResNetSpeakerEncoder(nn.Module):
"""
# map to the waveform size
if self.use_torch_spec:
num_frames = num_frames * self.audio_config['hop_length']
num_frames = num_frames * self.audio_config["hop_length"]
max_len = x.shape[1]
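Editor's note: the reformatted `PreEmphasis` module is a one-tap high-pass filter, y[t] = x[t] - coef * x[t-1], written as a `conv1d`. A quick standalone check of that equivalence, reusing the 0.97 coefficient and reflect padding from the diff:

```python
import torch

coef = 0.97
filt = torch.FloatTensor([-coef, 1.0]).unsqueeze(0).unsqueeze(0)  # shape (1, 1, 2)

x = torch.randn(1, 8)                                             # (batch, samples)
x_pad = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
y = torch.nn.functional.conv1d(x_pad, filt).squeeze(1)

# direct difference equation for t >= 1 (t == 0 uses the reflected sample)
y_ref = x[:, 1:] - coef * x[:, :-1]
print(torch.allclose(y[:, 1:], y_ref, atol=1e-6))  # True
```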

View File

@ -179,10 +179,12 @@ def setup_model(c):
c.model_params["num_lstm_layers"],
)
elif c.model_params["model_name"].lower() == "resnet":
model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"],
model = ResNetSpeakerEncoder(
input_dim=c.model_params["input_dim"],
proj_dim=c.model_params["proj_dim"],
log_input=c.model_params.get("log_input", False),
use_torch_spec=c.model_params.get("use_torch_spec", False),
audio_config=c.audio
audio_config=c.audio,
)
return model
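Editor's note: for orientation, the "resnet" branch above expects `model_params` roughly shaped like the sketch below; the key names follow the keyword arguments in the diff, while the values are illustrative assumptions only:

```python
# Illustrative values only; key names match the setup_model() call above.
model_params = {
    "model_name": "resnet",
    "input_dim": 64,         # e.g. number of mel bands fed to the encoder
    "proj_dim": 512,         # size of the output speaker embedding
    "log_input": True,
    "use_torch_spec": True,  # build the torchaudio spectrogram front-end
}
```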

View File

@ -265,7 +265,9 @@ class Trainer:
config = self.config.model_args if hasattr(self.config, "model_args") else self.config
# save speakers json
if config.use_language_embedding and self.model.language_manager.num_languages > 1:
self.model.language_manager.save_language_ids_to_file(os.path.join(self.output_path, "language_ids.json"))
self.model.language_manager.save_language_ids_to_file(
os.path.join(self.output_path, "language_ids.json")
)
if hasattr(self.config, "model_args"):
self.config.model_args["num_languages"] = self.model.language_manager.num_languages
else:
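Editor's note: the `language_ids.json` written above is, in essence, a plain name-to-index mapping. A sketch of producing such a file; the language names and indices are assumptions for illustration:

```python
import json
import os

language_id_mapping = {"en": 0, "fr-fr": 1, "pt-br": 2}  # assumed content
output_path = "."

with open(os.path.join(output_path, "language_ids.json"), "w", encoding="utf-8") as f:
    json.dump(language_id_mapping, f, indent=4)
```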

View File

@ -542,6 +542,7 @@ class TTSDataset(Dataset):
)
)
class PitchExtractor:
"""Pitch Extractor for computing F0 from wav files.
Args:

View File

@ -304,7 +304,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None):
return items
def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None): # pylint: disable=unused-argument
def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None): # pylint: disable=unused-argument
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
items = []
txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)

View File

@ -602,7 +602,7 @@ class VitsGeneratorLoss(nn.Module):
fine_tuning_mode=0,
use_speaker_encoder_as_loss=False,
gt_spk_emb=None,
syn_spk_emb=None
syn_spk_emb=None,
):
"""
Shapes:
@ -638,7 +638,9 @@ class VitsGeneratorLoss(nn.Module):
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
if use_speaker_encoder_as_loss:
loss_se = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
loss_se = (
-torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
)
loss += loss_se
return_dict["loss_spk_encoder"] = loss_se

View File

@ -178,7 +178,14 @@ class StochasticDurationPredictor(nn.Module):
"""
def __init__(
self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0, language_emb_dim=None
self,
in_channels: int,
hidden_channels: int,
kernel_size: int,
dropout_p: float,
num_flows=4,
cond_channels=0,
language_emb_dim=None,
):
super().__init__()

View File

@ -246,7 +246,9 @@ class BaseTTS(BaseModel):
# setup multi-speaker attributes
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
if hasattr(config, "model_args"):
speaker_id_mapping = self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
speaker_id_mapping = (
self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
)
d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
config.use_d_vector_file = config.model_args.use_d_vector_file
else:
@ -262,7 +264,9 @@ class BaseTTS(BaseModel):
custom_symbols = self.make_symbols(self.config)
if hasattr(self, "language_manager"):
language_id_mapping = self.language_manager.language_id_mapping if self.args.use_language_embedding else None
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
)
else:
language_id_mapping = None

View File

@ -229,7 +229,6 @@ class VitsArgs(Coqpit):
freeze_waveform_decoder: bool = False
class Vits(BaseTTS):
"""VITS TTS model
@ -306,7 +305,7 @@ class Vits(BaseTTS):
args.num_layers_text_encoder,
args.kernel_size_text_encoder,
args.dropout_p_text_encoder,
language_emb_dim=self.embedded_language_dim
language_emb_dim=self.embedded_language_dim,
)
self.posterior_encoder = PosteriorEncoder(
@ -389,16 +388,26 @@ class Vits(BaseTTS):
# TODO: make this a function
if config.use_speaker_encoder_as_loss:
if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
raise RuntimeError(
" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
)
self.speaker_manager.init_speaker_encoder(
config.speaker_encoder_model_path, config.speaker_encoder_config_path
)
self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
for param in self.speaker_encoder.parameters():
param.requires_grad = False
print(" > External Speaker Encoder Loaded !!")
if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
if (
hasattr(self.speaker_encoder, "audio_config")
and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
):
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.audio_config["sample_rate"],
new_freq=self.speaker_encoder.audio_config["sample_rate"],
)
else:
self.audio_transform = None
else:
@ -529,7 +538,13 @@ class Vits(BaseTTS):
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
return {
"text": text,
"speaker_id": speaker_id,
"style_wav": style_wav,
"d_vector": d_vector,
"language_id": language_id,
}
def forward(
self,
@ -567,7 +582,7 @@ class Vits(BaseTTS):
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# language embedding
lang_emb=None
lang_emb = None
if self.args.use_language_embedding and lid is not None:
lang_emb = self.emb_l(lid).unsqueeze(-1)
@ -621,9 +636,9 @@ class Vits(BaseTTS):
o = self.waveform_decoder(z_slice, g=g)
wav_seg = segment(
waveform.transpose(1, 2),
slice_ids * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
waveform.transpose(1, 2),
slice_ids * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
)
if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
@ -653,7 +668,7 @@ class Vits(BaseTTS):
"logs_q": logs_q,
"waveform_seg": wav_seg,
"gt_spk_emb": gt_spk_emb,
"syn_spk_emb": syn_spk_emb
"syn_spk_emb": syn_spk_emb,
}
)
return outputs
@ -695,7 +710,7 @@ class Vits(BaseTTS):
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# language embedding
lang_emb=None
lang_emb = None
if self.args.use_language_embedding and lid is not None:
lang_emb = self.emb_l(lid).unsqueeze(-1)
@ -737,9 +752,9 @@ class Vits(BaseTTS):
o = self.waveform_decoder(z_slice, g=g)
wav_seg = segment(
waveform.transpose(1, 2),
slice_ids * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
waveform.transpose(1, 2),
slice_ids * self.config.audio.hop_length,
self.args.spec_segment_size * self.config.audio.hop_length,
)
if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
@ -770,7 +785,7 @@ class Vits(BaseTTS):
"logs_q": logs_q,
"waveform_seg": wav_seg,
"gt_spk_emb": gt_spk_emb,
"syn_spk_emb": syn_spk_emb
"syn_spk_emb": syn_spk_emb,
}
)
return outputs
@ -790,14 +805,16 @@ class Vits(BaseTTS):
g = self.emb_g(sid).unsqueeze(-1)
# language embedding
lang_emb=None
lang_emb = None
if self.args.use_language_embedding and lid is not None:
lang_emb = self.emb_l(lid).unsqueeze(-1)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
if self.args.use_sdp:
logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb)
logw = self.duration_predictor(
x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
)
else:
logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
@ -866,7 +883,7 @@ class Vits(BaseTTS):
for param in self.text_encoder.parameters():
param.requires_grad = False
if hasattr(self, 'emb_l'):
if hasattr(self, "emb_l"):
for param in self.emb_l.parameters():
param.requires_grad = False
@ -932,7 +949,7 @@ class Vits(BaseTTS):
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion[optimizer_idx](
waveform_hat=outputs["model_outputs"].float(),
waveform= outputs["waveform_seg"].float(),
waveform=outputs["waveform_seg"].float(),
z_p=outputs["z_p"].float(),
logs_q=outputs["logs_q"].float(),
m_p=outputs["m_p"].float(),
@ -945,7 +962,7 @@ class Vits(BaseTTS):
fine_tuning_mode=self.args.fine_tuning_mode,
use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
gt_spk_emb=outputs["gt_spk_emb"],
syn_spk_emb=outputs["syn_spk_emb"]
syn_spk_emb=outputs["syn_spk_emb"],
)
# ignore duration loss if fine tuning mode is on
if not self.args.fine_tuning_mode:
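Editor's note: one of the reformatted blocks above sets up a resampler for the case where the external speaker encoder was trained at a different sample rate than the TTS audio config. A standalone sketch of that bridge; the two rates below are assumptions for illustration:

```python
import torch
import torchaudio

tts_sample_rate = 22050        # assumed TTS audio config rate
encoder_sample_rate = 16000    # assumed speaker encoder rate

audio_transform = torchaudio.transforms.Resample(
    orig_freq=tts_sample_rate, new_freq=encoder_sample_rate
)

wav = torch.randn(1, tts_sample_rate)   # one second of audio at the TTS rate
wav_for_encoder = audio_transform(wav)  # ~one second at the encoder rate
print(wav_for_encoder.shape)            # torch.Size([1, 16000])
```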

View File

@ -1,13 +1,14 @@
import os
import json
import torch
import os
from typing import Dict, List, Tuple
import fsspec
import numpy as np
from typing import Dict, Tuple, List
import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler
class LanguageManager:
"""Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
in a way that can be queried by language.
@ -20,7 +21,9 @@ class LanguageManager:
>>> manager = LanguageManager(language_id_file_path=language_id_file_path)
>>> language_id_mapper = manager.language_ids
"""
language_id_mapping: Dict = {}
def __init__(
self,
language_id_file_path: str = "",
@ -85,6 +88,7 @@ class LanguageManager:
"""
self._save_json(file_path, self.language_id_mapping)
def _set_file_path(path):
"""Find the language_ids.json under the given path or the above it.
Intended to band aid the different paths returned in restored and continued training."""
@ -97,6 +101,7 @@ def _set_file_path(path):
return path_continue
return None
def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
"""Initiate a `LanguageManager` instance by the provided config.
@ -118,7 +123,7 @@ def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None)
# restoring language manager from a previous run.
if language_file:
language_manager.set_language_ids_from_file(language_file)
if language_manager.num_languages > 0:
if language_manager.num_languages > 0:
print(
" > Language manager is loaded with {} languages: {}".format(
language_manager.num_languages, ", ".join(language_manager.language_names)
@ -126,11 +131,12 @@ def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None)
)
return language_manager
def get_language_weighted_sampler(items: list):
language_names = np.array([item[3] for item in items])
unique_language_names = np.unique(language_names).tolist()
language_ids = [unique_language_names.index(l) for l in language_names]
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
weight_language = 1. / language_count
weight_language = 1.0 / language_count
dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
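Editor's note: `get_language_weighted_sampler()` weights every item by the inverse frequency of its language so the `WeightedRandomSampler` draws languages more evenly; the speaker-weighted sampler in the next file does the same thing on `item[2]`. A runnable toy example (the item tuples are an assumption about the dataset layout):

```python
import numpy as np
import torch
from torch.utils.data.sampler import WeightedRandomSampler

# assumed item layout: (text, wav_path, speaker_name, language_name)
items = [
    ("text a", "a.wav", "spk0", "en"),
    ("text b", "b.wav", "spk1", "en"),
    ("text c", "c.wav", "spk2", "pt"),
]

language_names = np.array([item[3] for item in items])
unique_language_names = np.unique(language_names).tolist()
language_ids = [unique_language_names.index(l) for l in language_names]
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
weight_language = 1.0 / language_count  # en -> 0.5, pt -> 1.0
sample_weights = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()

sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
print(list(sampler))  # 3 indices drawn with the per-language weights above
```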

View File

@ -432,11 +432,12 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
speaker_manager.save_speaker_ids_to_file(out_file_path)
return speaker_manager
def get_speaker_weighted_sampler(items: list):
speaker_names = np.array([item[2] for item in items])
unique_speaker_names = np.unique(speaker_names).tolist()
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
weight_speaker = 1. / speaker_count
weight_speaker = 1.0 / speaker_count
dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))

View File

@ -136,8 +136,9 @@ def phoneme_cleaners(text):
text = collapse_whitespace(text)
return text
def multilingual_cleaners(text):
'''Pipeline for multilingual text'''
"""Pipeline for multilingual text"""
text = lowercase(text)
text = replace_symbols(text, lang=None)
text = remove_aux_symbols(text)
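Editor's note: for context, `multilingual_cleaners` is a small text-normalisation pipeline applied before phonemization. A toy sketch of the shape of such a pipeline; `lowercase` and `collapse_whitespace` below are stand-ins, not the real `TTS.tts.utils.text.cleaners` helpers:

```python
import re


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(r"\s+", " ", text).strip()


def multilingual_cleaners_sketch(text):
    """Toy pipeline for multilingual text (symbol handling omitted)."""
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


print(multilingual_cleaners_sketch("  Bonjour,   MONDE  "))  # "bonjour, monde"
```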

View File

@ -3,19 +3,27 @@ import os
import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.vits_config import VitsConfig
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
dataset_config1 = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", language="en"
name="ljspeech",
meta_file_train="metadata.csv",
meta_file_val="metadata.csv",
path="tests/data/ljspeech",
language="en",
)
dataset_config2 = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", language="en2"
name="ljspeech",
meta_file_train="metadata.csv",
meta_file_val="metadata.csv",
path="tests/data/ljspeech",
language="en2",
)
config = VitsConfig(