mirror of https://github.com/coqui-ai/TTS.git
make style
parent e1bdeacd2e
commit 3e9ca4b95d
@@ -1,14 +1,15 @@
 """Find all the unique characters in a dataset"""
 import argparse
+import multiprocessing
 from argparse import RawTextHelpFormatter

+import numpy
+from tqdm.contrib.concurrent import process_map
+
 from TTS.config import load_config
 from TTS.tts.datasets import load_meta_data
-
-import numpy
-import multiprocessing
 from TTS.tts.utils.text import text2phone
-from tqdm.contrib.concurrent import process_map
+

 def compute_phonemes(item):
     try:
@@ -19,6 +20,7 @@ def compute_phonemes(item):
         return []
     return list(set(ph))

+
 def main():
     global c
     # pylint: disable=bad-option-value
@@ -51,8 +53,6 @@ def main():
     phones_force_lower = [c.lower() for c in phones]
     phones_force_lower = set(phones_force_lower)

-
-
     print(f" > Number of unique phonemes: {len(phones)}")
     print(f" > Unique phonemes: {''.join(sorted(phones))}")
     print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")

@@ -1,26 +1,27 @@
 # This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
-import os
-import tqdm
-import glob
 import argparse
-import pathlib
-
 import collections
 import contextlib
+import glob
+import multiprocessing
+import os
+import pathlib
 import sys
 import wave
+from itertools import chain
+
 import numpy as np
+import tqdm
 import webrtcvad
 from tqdm.contrib.concurrent import process_map
-import multiprocessing
-from itertools import chain
+

 def read_wave(path):
     """Reads a .wav file.

     Takes the path, and returns (PCM audio data, sample rate).
     """
-    with contextlib.closing(wave.open(path, 'rb')) as wf:
+    with contextlib.closing(wave.open(path, "rb")) as wf:
         num_channels = wf.getnchannels()
         assert num_channels == 1
         sample_width = wf.getsampwidth()
@@ -36,7 +37,7 @@ def write_wave(path, audio, sample_rate):

     Takes path, PCM audio data, and sample rate.
     """
-    with contextlib.closing(wave.open(path, 'wb')) as wf:
+    with contextlib.closing(wave.open(path, "wb")) as wf:
         wf.setnchannels(1)
         wf.setsampwidth(2)
         wf.setframerate(sample_rate)
@@ -45,6 +46,7 @@ def write_wave(path, audio, sample_rate):

 class Frame(object):
     """Represents a "frame" of audio data."""
+
     def __init__(self, bytes, timestamp, duration):
         self.bytes = bytes
         self.timestamp = timestamp
@@ -64,13 +66,12 @@ def frame_generator(frame_duration_ms, audio, sample_rate):
     timestamp = 0.0
     duration = (float(n) / sample_rate) / 2.0
     while offset + n < len(audio):
-        yield Frame(audio[offset:offset + n], timestamp, duration)
+        yield Frame(audio[offset : offset + n], timestamp, duration)
         timestamp += duration
         offset += n


-def vad_collector(sample_rate, frame_duration_ms,
-                  padding_duration_ms, vad, frames):
+def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
     """Filters out non-voiced audio frames.

     Given a webrtcvad.Vad and a source of audio frames, yields only
@@ -133,25 +134,26 @@ def vad_collector(sample_rate, frame_duration_ms,
             # unvoiced, then enter NOTTRIGGERED and yield whatever
             # audio we've collected.
             if num_unvoiced > 0.9 * ring_buffer.maxlen:
-                #sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
+                # sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
                 triggered = False
-                yield b''.join([f.bytes for f in voiced_frames])
+                yield b"".join([f.bytes for f in voiced_frames])
                 ring_buffer.clear()
                 voiced_frames = []
     # If we have any leftover voiced audio when we run out of input,
     # yield it.
     if voiced_frames:
-        yield b''.join([f.bytes for f in voiced_frames])
+        yield b"".join([f.bytes for f in voiced_frames])
+

 def remove_silence(filepath):
     filename = os.path.basename(filepath)
-    output_path = filepath.replace(os.path.join(args.input_dir, ''),os.path.join(args.output_dir, ''))
+    output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
     # ignore if the file exists
     if os.path.exists(output_path) and not args.force:
         return False
     # create all directory structure
     pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
-    padding_duration_ms = 300 # default 300
+    padding_duration_ms = 300  # default 300
     audio, sample_rate = read_wave(filepath)
     vad = webrtcvad.Vad(int(args.aggressiveness))
     frames = frame_generator(30, audio, sample_rate)
@@ -180,6 +182,7 @@ def remove_silence(filepath):
         # if fail to remove silence just write the file
         write_wave(output_path, audio, sample_rate)

+
 def preprocess_audios():
     files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True))
     print("> Number of files: ", len(files))
@@ -193,21 +196,31 @@ def preprocess_audios():
     else:
         print("> No files Found !")

+
 if __name__ == "__main__":
     """
     usage
     python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2
     """
     parser = argparse.ArgumentParser()
-    parser.add_argument('-i', '--input_dir', type=str, default='../VCTK-Corpus',
-                        help='Dataset root dir')
-    parser.add_argument('-o', '--output_dir', type=str, default='../VCTK-Corpus-removed-silence',
-                        help='Output Dataset dir')
-    parser.add_argument('-f', '--force', type=bool, default=True,
-                        help='Force the replace of exists files')
-    parser.add_argument('-g', '--glob', type=str, default='**/*.wav',
-                        help='path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav')
-    parser.add_argument('-a', '--aggressiveness', type=int, default=2,
-                        help='set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.')
+    parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir")
+    parser.add_argument(
+        "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
+    )
+    parser.add_argument("-f", "--force", type=bool, default=True, help="Force the replace of exists files")
+    parser.add_argument(
+        "-g",
+        "--glob",
+        type=str,
+        default="**/*.wav",
+        help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav",
+    )
+    parser.add_argument(
+        "-a",
+        "--aggressiveness",
+        type=int,
+        default=2,
+        help="set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.",
+    )
     args = parser.parse_args()
     preprocess_audios()

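For reference, the script above is a thin wrapper around the frame-level py-webrtcvad API. A minimal sketch of that underlying call, assuming 16-bit mono PCM at 16 kHz and a 30 ms frame (values chosen for illustration, not taken from this diff):

import webrtcvad

sample_rate = 16000
frame_ms = 30
n_bytes = int(sample_rate * frame_ms / 1000) * 2  # 2 bytes per 16-bit sample
frame = b"\x00" * n_bytes  # a silent dummy frame

vad = webrtcvad.Vad(2)  # aggressiveness: 0 (least) .. 3 (most), as in --aggressiveness above
print(vad.is_speech(frame, sample_rate))  # expected False for pure silence
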
@@ -5,20 +5,20 @@ import torch.nn as nn

 from TTS.utils.io import load_fsspec

+
 class PreEmphasis(torch.nn.Module):
     def __init__(self, coefficient=0.97):
         super().__init__()
         self.coefficient = coefficient
-        self.register_buffer(
-            'filter', torch.FloatTensor([-self.coefficient, 1.]).unsqueeze(0).unsqueeze(0)
-        )
+        self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))

     def forward(self, x):
         assert len(x.size()) == 2

-        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), 'reflect')
+        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
         return torch.nn.functional.conv1d(x, self.filter).squeeze(1)

+
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
         super(SELayer, self).__init__()
@@ -110,8 +110,15 @@ class ResNetSpeakerEncoder(nn.Module):
         if self.use_torch_spec:
             self.torch_spec = torch.nn.Sequential(
                 PreEmphasis(audio_config["preemphasis"]),
-                torchaudio.transforms.MelSpectrogram(sample_rate=audio_config["sample_rate"], n_fft=audio_config["fft_size"], win_length=audio_config["win_length"], hop_length=audio_config["hop_length"], window_fn=torch.hamming_window, n_mels=audio_config["num_mels"])
-            )
+                torchaudio.transforms.MelSpectrogram(
+                    sample_rate=audio_config["sample_rate"],
+                    n_fft=audio_config["fft_size"],
+                    win_length=audio_config["win_length"],
+                    hop_length=audio_config["hop_length"],
+                    window_fn=torch.hamming_window,
+                    n_mels=audio_config["num_mels"],
+                ),
+            )
         else:
             self.torch_spec = None

@@ -213,7 +220,7 @@ class ResNetSpeakerEncoder(nn.Module):
         """
         # map to the waveform size
         if self.use_torch_spec:
-            num_frames = num_frames * self.audio_config['hop_length']
+            num_frames = num_frames * self.audio_config["hop_length"]

         max_len = x.shape[1]

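The `use_torch_spec` branch above only chains the `PreEmphasis` module with `torchaudio.transforms.MelSpectrogram`. A standalone sketch of that pipeline, reusing the `PreEmphasis` class from the hunk above with illustrative audio settings rather than values from any 🐸TTS config:

import torch
import torchaudio

torch_spec = torch.nn.Sequential(
    PreEmphasis(coefficient=0.97),  # class defined in the hunk above
    torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_fft=512, win_length=400, hop_length=160, window_fn=torch.hamming_window, n_mels=64
    ),
)
wav = torch.randn(8, 16000)  # [batch, samples]
mel = torch_spec(wav)  # -> [batch, n_mels, frames]
print(mel.shape)
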
@@ -179,10 +179,12 @@ def setup_model(c):
             c.model_params["num_lstm_layers"],
         )
     elif c.model_params["model_name"].lower() == "resnet":
-        model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"],
+        model = ResNetSpeakerEncoder(
+            input_dim=c.model_params["input_dim"],
+            proj_dim=c.model_params["proj_dim"],
             log_input=c.model_params.get("log_input", False),
             use_torch_spec=c.model_params.get("use_torch_spec", False),
-            audio_config=c.audio
+            audio_config=c.audio,
         )
     return model

@@ -265,7 +265,9 @@ class Trainer:
         config = self.config.model_args if hasattr(self.config, "model_args") else self.config
         # save speakers json
         if config.use_language_embedding and self.model.language_manager.num_languages > 1:
-            self.model.language_manager.save_language_ids_to_file(os.path.join(self.output_path, "language_ids.json"))
+            self.model.language_manager.save_language_ids_to_file(
+                os.path.join(self.output_path, "language_ids.json")
+            )
             if hasattr(self.config, "model_args"):
                 self.config.model_args["num_languages"] = self.model.language_manager.num_languages
             else:

@@ -542,6 +542,7 @@ class TTSDataset(Dataset):
                 )
             )

+
 class PitchExtractor:
     """Pitch Extractor for computing F0 from wav files.
     Args:

@@ -304,7 +304,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None):
     return items


-def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None): # pylint: disable=unused-argument
+def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None):  # pylint: disable=unused-argument
     """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
     items = []
     txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)

@@ -602,7 +602,7 @@ class VitsGeneratorLoss(nn.Module):
         fine_tuning_mode=0,
         use_speaker_encoder_as_loss=False,
         gt_spk_emb=None,
-        syn_spk_emb=None
+        syn_spk_emb=None,
     ):
         """
         Shapes:
@@ -638,7 +638,9 @@ class VitsGeneratorLoss(nn.Module):
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration

         if use_speaker_encoder_as_loss:
-            loss_se = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
+            loss_se = (
+                -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
+            )
             loss += loss_se
             return_dict["loss_spk_encoder"] = loss_se

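The speaker-consistency term reformatted above is just a negative cosine similarity between ground-truth and synthesized speaker embeddings, scaled by a weight. A minimal standalone sketch with assumed shapes and an assumed alpha value:

import torch

gt_spk_emb = torch.randn(4, 512)  # [batch, embedding_dim], hypothetical shapes
syn_spk_emb = torch.randn(4, 512)
spk_encoder_loss_alpha = 9.0  # illustrative weight, not taken from this diff

loss_se = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * spk_encoder_loss_alpha
print(loss_se.item())
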
@@ -178,7 +178,14 @@ class StochasticDurationPredictor(nn.Module):
     """

     def __init__(
-        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0, language_emb_dim=None
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        kernel_size: int,
+        dropout_p: float,
+        num_flows=4,
+        cond_channels=0,
+        language_emb_dim=None,
     ):
         super().__init__()

@@ -246,7 +246,9 @@ class BaseTTS(BaseModel):
             # setup multi-speaker attributes
             if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
                 if hasattr(config, "model_args"):
-                    speaker_id_mapping = self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+                    speaker_id_mapping = (
+                        self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+                    )
                     d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
                     config.use_d_vector_file = config.model_args.use_d_vector_file
                 else:
@@ -262,7 +264,9 @@ class BaseTTS(BaseModel):
                 custom_symbols = self.make_symbols(self.config)

             if hasattr(self, "language_manager"):
-                language_id_mapping = self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+                language_id_mapping = (
+                    self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+                )
             else:
                 language_id_mapping = None

@@ -229,7 +229,6 @@ class VitsArgs(Coqpit):
     freeze_waveform_decoder: bool = False

-

 class Vits(BaseTTS):
     """VITS TTS model

@@ -306,7 +305,7 @@ class Vits(BaseTTS):
             args.num_layers_text_encoder,
             args.kernel_size_text_encoder,
             args.dropout_p_text_encoder,
-            language_emb_dim=self.embedded_language_dim
+            language_emb_dim=self.embedded_language_dim,
         )

         self.posterior_encoder = PosteriorEncoder(
@@ -389,16 +388,26 @@ class Vits(BaseTTS):
         # TODO: make this a function
         if config.use_speaker_encoder_as_loss:
             if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
-                raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
-            self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
+                raise RuntimeError(
+                    " [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
+                )
+            self.speaker_manager.init_speaker_encoder(
+                config.speaker_encoder_model_path, config.speaker_encoder_config_path
+            )
             self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
             for param in self.speaker_encoder.parameters():
                 param.requires_grad = False

             print(" > External Speaker Encoder Loaded !!")

-            if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
-                self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
+            if (
+                hasattr(self.speaker_encoder, "audio_config")
+                and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
+            ):
+                self.audio_transform = torchaudio.transforms.Resample(
+                    orig_freq=self.audio_config["sample_rate"],
+                    new_freq=self.speaker_encoder.audio_config["sample_rate"],
+                )
             else:
                 self.audio_transform = None
         else:
@@ -529,7 +538,13 @@ class Vits(BaseTTS):
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
             language_id = self.language_manager.language_id_mapping[language_name]

-        return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
+        return {
+            "text": text,
+            "speaker_id": speaker_id,
+            "style_wav": style_wav,
+            "d_vector": d_vector,
+            "language_id": language_id,
+        }

     def forward(
         self,
@@ -567,7 +582,7 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

@@ -621,9 +636,9 @@ class Vits(BaseTTS):
         o = self.waveform_decoder(z_slice, g=g)

         wav_seg = segment(
-                waveform.transpose(1, 2),
-                slice_ids * self.config.audio.hop_length,
-                self.args.spec_segment_size * self.config.audio.hop_length,
+            waveform.transpose(1, 2),
+            slice_ids * self.config.audio.hop_length,
+            self.args.spec_segment_size * self.config.audio.hop_length,
         )

         if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
@@ -653,7 +668,7 @@ class Vits(BaseTTS):
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
                 "gt_spk_emb": gt_spk_emb,
-                "syn_spk_emb": syn_spk_emb
+                "syn_spk_emb": syn_spk_emb,
             }
         )
         return outputs
@@ -695,7 +710,7 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

@@ -737,9 +752,9 @@ class Vits(BaseTTS):
         o = self.waveform_decoder(z_slice, g=g)

         wav_seg = segment(
-                waveform.transpose(1, 2),
-                slice_ids * self.config.audio.hop_length,
-                self.args.spec_segment_size * self.config.audio.hop_length,
+            waveform.transpose(1, 2),
+            slice_ids * self.config.audio.hop_length,
+            self.args.spec_segment_size * self.config.audio.hop_length,
         )

         if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
@@ -770,7 +785,7 @@ class Vits(BaseTTS):
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
                 "gt_spk_emb": gt_spk_emb,
-                "syn_spk_emb": syn_spk_emb
+                "syn_spk_emb": syn_spk_emb,
             }
         )
         return outputs
@@ -790,14 +805,16 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)

         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

         if self.args.use_sdp:
-            logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb)
+            logw = self.duration_predictor(
+                x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
+            )
         else:
             logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)

@@ -866,7 +883,7 @@ class Vits(BaseTTS):
             for param in self.text_encoder.parameters():
                 param.requires_grad = False

-            if hasattr(self, 'emb_l'):
+            if hasattr(self, "emb_l"):
                 for param in self.emb_l.parameters():
                     param.requires_grad = False

@@ -932,7 +949,7 @@ class Vits(BaseTTS):
             with autocast(enabled=False):  # use float32 for the criterion
                 loss_dict = criterion[optimizer_idx](
                     waveform_hat=outputs["model_outputs"].float(),
-                    waveform= outputs["waveform_seg"].float(),
+                    waveform=outputs["waveform_seg"].float(),
                     z_p=outputs["z_p"].float(),
                     logs_q=outputs["logs_q"].float(),
                     m_p=outputs["m_p"].float(),
@@ -945,7 +962,7 @@ class Vits(BaseTTS):
                     fine_tuning_mode=self.args.fine_tuning_mode,
                     use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
                     gt_spk_emb=outputs["gt_spk_emb"],
-                    syn_spk_emb=outputs["syn_spk_emb"]
+                    syn_spk_emb=outputs["syn_spk_emb"],
                 )
                 # ignore duration loss if fine tuning mode is on
                 if not self.args.fine_tuning_mode:

@@ -1,13 +1,14 @@
-import os
 import json
-import torch
+import os
+from typing import Dict, List, Tuple
+
 import fsspec
 import numpy as np
-from typing import Dict, Tuple, List
+import torch
 from coqpit import Coqpit
-
 from torch.utils.data.sampler import WeightedRandomSampler
+

 class LanguageManager:
     """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
     in a way that can be queried by language.
@@ -20,7 +21,9 @@ class LanguageManager:
         >>> manager = LanguageManager(language_id_file_path=language_id_file_path)
         >>> language_id_mapper = manager.language_ids
     """
+
     language_id_mapping: Dict = {}
+
     def __init__(
         self,
         language_id_file_path: str = "",
@@ -85,6 +88,7 @@ class LanguageManager:
         """
         self._save_json(file_path, self.language_id_mapping)

+
 def _set_file_path(path):
     """Find the language_ids.json under the given path or the above it.
     Intended to band aid the different paths returned in restored and continued training."""
@@ -97,6 +101,7 @@ def _set_file_path(path):
         return path_continue
     return None

+
 def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
     """Initiate a `LanguageManager` instance by the provided config.

@@ -118,7 +123,7 @@ def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None)
     # restoring language manager from a previous run.
     if language_file:
         language_manager.set_language_ids_from_file(language_file)
-    if language_manager.num_languages > 0:
+    if language_manager.num_languages > 0:
         print(
             " > Language manager is loaded with {} languages: {}".format(
                 language_manager.num_languages, ", ".join(language_manager.language_names)
@@ -126,11 +131,12 @@ def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None)
         )
     return language_manager

+
 def get_language_weighted_sampler(items: list):
     language_names = np.array([item[3] for item in items])
     unique_language_names = np.unique(language_names).tolist()
     language_ids = [unique_language_names.index(l) for l in language_names]
     language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
-    weight_language = 1. / language_count
+    weight_language = 1.0 / language_count
     dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
     return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))

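`get_language_weighted_sampler` above returns a plain `torch.utils.data.WeightedRandomSampler`, so it can replace shuffling in a `DataLoader`. A sketch of that usage with a hypothetical dataset and weights:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import WeightedRandomSampler

dataset = TensorDataset(torch.arange(10).float())
weights = torch.ones(10).double()  # e.g. 1 / per-language sample count, per item
sampler = WeightedRandomSampler(weights, len(weights))
loader = DataLoader(dataset, batch_size=4, sampler=sampler)  # sampler replaces shuffle=True
for (batch,) in loader:
    print(batch)
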
@@ -432,11 +432,12 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
         speaker_manager.save_speaker_ids_to_file(out_file_path)
     return speaker_manager

+
 def get_speaker_weighted_sampler(items: list):
     speaker_names = np.array([item[2] for item in items])
     unique_speaker_names = np.unique(speaker_names).tolist()
     speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
     speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
-    weight_speaker = 1. / speaker_count
+    weight_speaker = 1.0 / speaker_count
     dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
     return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))

@@ -136,8 +136,9 @@ def phoneme_cleaners(text):
     text = collapse_whitespace(text)
     return text

+
 def multilingual_cleaners(text):
-    '''Pipeline for multilingual text'''
+    """Pipeline for multilingual text"""
     text = lowercase(text)
     text = replace_symbols(text, lang=None)
     text = remove_aux_symbols(text)

@@ -3,19 +3,27 @@ import os
 import shutil

 from tests import get_device_id, get_tests_output_path, run_cli
-from TTS.tts.configs.vits_config import VitsConfig
 from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.configs.vits_config import VitsConfig

 config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
 output_path = os.path.join(get_tests_output_path(), "train_outputs")


 dataset_config1 = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", language="en"
+    name="ljspeech",
+    meta_file_train="metadata.csv",
+    meta_file_val="metadata.csv",
+    path="tests/data/ljspeech",
+    language="en",
 )

 dataset_config2 = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", language="en2"
+    name="ljspeech",
+    meta_file_train="metadata.csv",
+    meta_file_val="metadata.csv",
+    path="tests/data/ljspeech",
+    language="en2",
 )

 config = VitsConfig(