REBASED: Transform Speaker Encoder into a Generic Encoder and Implement Emotion Encoder training support (#1349)

* Rename Speaker encoder module to encoder

* Add a generic emotion dataset formatter

* Transform the Speaker Encoder dataset into a generic dataset and create an emotion encoder config

* Add class map in emotion config

* Add Base encoder config

* Add encoder evaluation script

* Fix the bug in plot_embeddings

* Enable Weight decay for encoder training

* Add argument to disable storage

* Add Perfect Sampler and remove storage

* Add evaluation during encoder training

* Fix lint checks

* Remove useless config parameter

* Activate evaluation in the speaker encoder test and use a multispeaker dataset for this test

* Fix unit tests

* Remove useless tests to speed up the aux_tests

* Use get_optimizer in Encoder

* Add BaseEncoder Class

* Fix the unit tests

* Add Perfect Batch Sampler unit test

* Add a function to compute encoder accuracy
Edresson Casanova 2022-03-11 10:43:40 -03:00 committed by GitHub
parent 36e9ea2f97
commit f81892483d
40 changed files with 962 additions and 2791 deletions


@@ -42,33 +42,35 @@ c_dataset = load_config(args.config_dataset_path)
meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
wav_files = meta_data_train + meta_data_eval
speaker_manager = SpeakerManager(
encoder_manager = SpeakerManager(
encoder_model_path=args.model_path,
encoder_config_path=args.config_path,
d_vectors_file_path=args.old_file,
use_cuda=args.use_cuda,
)
class_name_key = encoder_manager.speaker_encoder_config.class_name_key
# compute speaker embeddings
speaker_mapping = {}
for idx, wav_file in enumerate(tqdm(wav_files)):
if isinstance(wav_file, list):
speaker_name = wav_file[2]
wav_file = wav_file[1]
if isinstance(wav_file, dict):
class_name = wav_file[class_name_key]
wav_file = wav_file["audio_file"]
else:
speaker_name = None
class_name = None
wav_file_name = os.path.basename(wav_file)
if args.old_file is not None and wav_file_name in speaker_manager.clip_ids:
if args.old_file is not None and wav_file_name in encoder_manager.clip_ids:
# get the embedding from the old file
embedd = speaker_manager.get_d_vector_by_clip(wav_file_name)
embedd = encoder_manager.get_d_vector_by_clip(wav_file_name)
else:
# extract the embedding
embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
# create speaker_mapping if target dataset is defined
speaker_mapping[wav_file_name] = {}
speaker_mapping[wav_file_name]["name"] = speaker_name
speaker_mapping[wav_file_name]["name"] = class_name
speaker_mapping[wav_file_name]["embedding"] = embedd
if speaker_mapping:
@@ -81,5 +83,5 @@ if speaker_mapping:
os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)
# pylint: disable=W0212
speaker_manager._save_json(mapping_file_path, speaker_mapping)
encoder_manager._save_json(mapping_file_path, speaker_mapping)
print("Speaker embeddings saved at:", mapping_file_path)

TTS/bin/eval_encoder.py

@@ -0,0 +1,88 @@
import argparse
import torch
from argparse import RawTextHelpFormatter
from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.speakers import SpeakerManager
def compute_encoder_accuracy(dataset_items, encoder_manager):
class_name_key = encoder_manager.speaker_encoder_config.class_name_key
map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, 'map_classid_to_classname', None)
class_acc_dict = {}
# compute embeddings for all wav_files
for item in tqdm(dataset_items):
class_name = item[class_name_key]
wav_file = item["audio_file"]
# extract the embedding
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
if encoder_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
embedding = torch.FloatTensor(embedd).unsqueeze(0)
if encoder_manager.use_cuda:
embedding = embedding.cuda()
class_id = encoder_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
predicted_label = map_classid_to_classname[str(class_id)]
else:
predicted_label = None
if class_name is not None and predicted_label is not None:
is_equal = int(class_name == predicted_label)
if class_name not in class_acc_dict:
class_acc_dict[class_name] = [is_equal]
else:
class_acc_dict[class_name].append(is_equal)
else:
raise RuntimeError("Error: class_name or/and predicted_label are None")
acc_avg = 0
for key, values in class_acc_dict.items():
acc = sum(values)/len(values)
print("Class", key, "Accuracy:", acc)
acc_avg += acc
print("Average Accuracy:", acc_avg/len(class_acc_dict))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""Compute the accuracy of the encoder.\n\n"""
"""
Example runs:
python TTS/bin/eval_encoder.py emotion_encoder_model.pth.tar emotion_encoder_config.json dataset_config.json
""",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
parser.add_argument(
"config_path",
type=str,
help="Path to model config file.",
)
parser.add_argument(
"config_dataset_path",
type=str,
help="Path to dataset config file.",
)
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
args = parser.parse_args()
c_dataset = load_config(args.config_dataset_path)
meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
items = meta_data_train + meta_data_eval
enc_manager = SpeakerManager(
encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
)
compute_encoder_accuracy(items, enc_manager)


@@ -10,16 +10,16 @@ import torch
from torch.utils.data import DataLoader
from trainer.torch import NoamLR
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
from TTS.speaker_encoder.utils.training import init_training
from TTS.speaker_encoder.utils.visual import plot_embeddings
from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.io import load_fsspec
from TTS.utils.radam import RAdam
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.io import copy_model_files
from trainer.trainer_utils import get_optimizer
from TTS.utils.training import check_update
torch.backends.cudnn.enabled = True
@@ -32,164 +32,238 @@ print(" > Number of GPUs: ", num_gpus)
def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
dataset = EncoderDataset(
c,
ap,
meta_data_eval if is_val else meta_data_train,
voice_len=c.voice_len,
num_utter_per_class=num_utter_per_class,
num_classes_in_batch=num_classes_in_batch,
verbose=verbose,
augmentation_config=c.audio_augmentation if not is_val else None,
use_torch_spec=c.model_params.get("use_torch_spec", False),
)
# get classes list
classes = dataset.get_class_list()
sampler = PerfectBatchSampler(
dataset.items,
classes,
batch_size=num_classes_in_batch*num_utter_per_class, # total batch size
num_classes_in_batch=num_classes_in_batch,
num_gpus=1,
shuffle=not is_val,
drop_last=True)
if len(classes) < num_classes_in_batch:
if is_val:
raise RuntimeError(f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !")
raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !")
# set the classes explicitly to avoid getting a wrong class_id when the numbers of training and eval classes differ
if is_val:
loader = None
else:
dataset = SpeakerEncoderDataset(
ap,
meta_data_eval if is_val else meta_data_train,
voice_len=c.voice_len,
num_utter_per_speaker=c.num_utters_per_speaker,
num_speakers_in_batch=c.num_speakers_in_batch,
skip_speakers=c.skip_speakers,
storage_size=c.storage["storage_size"],
sample_from_storage_p=c.storage["sample_from_storage_p"],
verbose=verbose,
augmentation_config=c.audio_augmentation,
use_torch_spec=c.model_params.get("use_torch_spec", False),
)
dataset.set_classes(train_classes)
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=c.num_speakers_in_batch,
shuffle=False,
num_workers=c.num_loader_workers,
collate_fn=dataset.collate_fn,
)
return loader, dataset.get_num_speakers()
loader = DataLoader(
dataset,
num_workers=c.num_loader_workers,
batch_sampler=sampler,
collate_fn=dataset.collate_fn,
)
return loader, classes, dataset.get_map_classid_to_classname()
def train(model, optimizer, scheduler, criterion, data_loader, global_step):
def evaluation(model, criterion, data_loader, global_step):
eval_loss = 0
for _, data in enumerate(data_loader):
with torch.no_grad():
# setup input data
inputs, labels = data
# group the samples of each class in the batch: the perfect sampler yields [3,2,1,3,2,1] but we need [3,3,2,2,1,1]
labels = torch.transpose(labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1).reshape(labels.shape)
inputs = torch.transpose(inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
# dispatch data to GPU
if use_cuda:
inputs = inputs.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
# forward pass model
outputs = model(inputs)
# loss computation
loss = criterion(outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels)
eval_loss += loss.item()
eval_avg_loss = eval_loss/len(data_loader)
# save stats
dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
# plot the last batch in the evaluation
figures = {
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
}
dashboard_logger.eval_figures(global_step, figures)
return eval_avg_loss
def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
model.train()
epoch_time = 0
best_loss = float("inf")
avg_loss = 0
avg_loss_all = 0
avg_loader_time = 0
end_time = time.time()
for epoch in range(c.epochs):
tot_loss = 0
epoch_time = 0
for _, data in enumerate(data_loader):
start_time = time.time()
for _, data in enumerate(data_loader):
start_time = time.time()
# setup input data
inputs, labels = data
# group the samples of each class in the batch: the perfect sampler yields [3,2,1,3,2,1] but we need [3,3,2,2,1,1]
labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
# ToDo: move it to a unit test
# labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
# inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
# idx = 0
# for j in range(0, c.num_classes_in_batch, 1):
# for i in range(j, len(labels), c.num_classes_in_batch):
# if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])):
# print("Invalid")
# print(labels)
# exit()
# idx += 1
# labels = labels_converted
# inputs = inputs_converted
# setup input data
inputs, labels = data
loader_time = time.time() - end_time
global_step += 1
loader_time = time.time() - end_time
global_step += 1
# setup lr
if c.lr_decay:
scheduler.step()
optimizer.zero_grad()
# setup lr
if c.lr_decay:
scheduler.step()
optimizer.zero_grad()
# dispatch data to GPU
if use_cuda:
inputs = inputs.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
# dispatch data to GPU
if use_cuda:
inputs = inputs.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
# forward pass model
outputs = model(inputs)
# forward pass model
outputs = model(inputs)
# loss computation
loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1), labels)
loss.backward()
grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()
# loss computation
loss = criterion(outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels)
loss.backward()
grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()
step_time = time.time() - start_time
epoch_time += step_time
step_time = time.time() - start_time
epoch_time += step_time
# Averaged Loss and Averaged Loader Time
avg_loss = 0.01 * loss.item() + 0.99 * avg_loss if avg_loss != 0 else loss.item()
num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
avg_loader_time = (
1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
if avg_loader_time != 0
else loader_time
)
current_lr = optimizer.param_groups[0]["lr"]
# accumulate the total epoch loss
tot_loss += loss.item()
if global_step % c.steps_plot_stats == 0:
# Plot Training Epoch Stats
train_stats = {
"loss": avg_loss,
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time,
"avg_loader_time": avg_loader_time,
}
dashboard_logger.train_epoch_stats(global_step, train_stats)
figures = {
# FIXME: not constant
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
}
dashboard_logger.train_figures(global_step, figures)
if global_step % c.print_step == 0:
print(
" | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} "
"StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
global_step, loss.item(), avg_loss, grad_norm, step_time, loader_time, avg_loader_time, current_lr
),
flush=True,
# Averaged Loader Time
num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
avg_loader_time = (
1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
if avg_loader_time != 0
else loader_time
)
avg_loss_all += avg_loss
current_lr = optimizer.param_groups[0]["lr"]
if global_step >= c.max_train_step or global_step % c.save_step == 0:
# save best model only
best_loss = save_best_model(model, optimizer, criterion, avg_loss, best_loss, OUT_PATH, global_step)
avg_loss_all = 0
if global_step >= c.max_train_step:
break
if global_step % c.steps_plot_stats == 0:
# Plot Training Epoch Stats
train_stats = {
"loss": loss.item(),
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time,
"avg_loader_time": avg_loader_time,
}
dashboard_logger.train_epoch_stats(global_step, train_stats)
figures = {
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
}
dashboard_logger.train_figures(global_step, figures)
end_time = time.time()
if global_step % c.print_step == 0:
print(
" | > Step:{} Loss:{:.5f} GradNorm:{:.5f} "
"StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr
),
flush=True,
)
return avg_loss, global_step
if global_step % c.save_step == 0:
# save model
save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
end_time = time.time()
print("")
print(
">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
"EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
epoch, tot_loss/len(data_loader), grad_norm, epoch_time, avg_loader_time
),
flush=True,
)
# evaluation
if c.run_eval:
model.eval()
eval_loss = evaluation(model, criterion, eval_data_loader, global_step)
print("\n\n")
print("--> EVAL PERFORMANCE")
print(
" | > Epoch:{} AvgLoss: {:.5f} ".format(
epoch, eval_loss
),
flush=True,
)
# save the best checkpoint
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
model.train()
return best_loss, global_step
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global meta_data_train
global meta_data_eval
global train_classes
ap = AudioProcessor(**c.audio)
model = setup_speaker_encoder_model(c)
optimizer = RAdam(model.parameters(), lr=c.lr)
optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model)
# pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False)
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)
data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)
if c.loss == "ge2e":
criterion = GE2ELoss(loss_method="softmax")
elif c.loss == "angleproto":
criterion = AngleProtoLoss()
elif c.loss == "softmaxproto":
criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_speakers)
train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
if c.run_eval:
eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)
else:
raise Exception("The %s not is a loss supported" % c.loss)
eval_data_loader = None
num_classes = len(train_classes)
criterion = model.get_criterion(c, num_classes)
if c.loss == "softmaxproto" and c.model != "speaker_encoder":
c.map_classid_to_classname = map_classid_to_classname
copy_model_files(c, OUT_PATH)
if args.restore_path:
checkpoint = load_fsspec(args.restore_path)
try:
model.load_state_dict(checkpoint["model"])
if "criterion" in checkpoint:
criterion.load_state_dict(checkpoint["criterion"])
except (KeyError, RuntimeError):
print(" > Partial model initialization.")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
group["lr"] = c.lr
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
args.restore_step = checkpoint["step"]
criterion, args.restore_step = model.load_checkpoint(c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion)
print(" > Model restored from step %d" % args.restore_step, flush=True)
else:
args.restore_step = 0
@@ -206,7 +280,7 @@ def main(args): # pylint: disable=redefined-outer-name
criterion.cuda()
global_step = args.restore_step
_, global_step = train(model, optimizer, scheduler, criterion, data_loader, global_step)
_, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step)
if __name__ == "__main__":


@@ -37,7 +37,7 @@ def register_config(model_name: str) -> Coqpit:
"""
config_class = None
config_name = model_name + "_config"
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs"]
for path in paths:
try:
config_class = find_module(path, config_name)


@@ -7,10 +7,10 @@ from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTr
@dataclass
class SpeakerEncoderConfig(BaseTrainingConfig):
"""Defines parameters for Speaker Encoder model."""
class BaseEncoderConfig(BaseTrainingConfig):
"""Defines parameters for a Generic Encoder model."""
model: str = "speaker_encoder"
model: str = None
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# model params
@@ -27,34 +27,33 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
audio_augmentation: Dict = field(default_factory=lambda: {})
storage: Dict = field(
default_factory=lambda: {
"sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage
"storage_size": 15, # the size of the in-memory storage with respect to a single batch
}
)
# training params
max_train_step: int = 1000000 # end training when number of training steps reaches this value.
epochs: int = 10000
loss: str = "angleproto"
grad_clip: float = 3.0
lr: float = 0.0001
optimizer: str = "radam"
optimizer_params: Dict = field(default_factory=lambda: {
"betas": [0.9, 0.999],
"weight_decay": 0
})
lr_decay: bool = False
warmup_steps: int = 4000
wd: float = 1e-6
# logging params
tb_model_param_stats: bool = False
steps_plot_stats: int = 10
checkpoint: bool = True
save_step: int = 1000
print_step: int = 20
run_eval: bool = False
# data loader
num_speakers_in_batch: int = MISSING
num_utters_per_speaker: int = MISSING
num_classes_in_batch: int = MISSING
num_utter_per_class: int = MISSING
eval_num_classes_in_batch: int = None
eval_num_utter_per_class: int = None
num_loader_workers: int = MISSING
skip_speakers: bool = False
voice_len: float = 1.6
def check_values(self):


@@ -0,0 +1,12 @@
from dataclasses import asdict, dataclass
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
@dataclass
class EmotionEncoderConfig(BaseEncoderConfig):
"""Defines parameters for Emotion Encoder model."""
model: str = "emotion_encoder"
map_classid_to_classname: dict = None
class_name_key: str = "emotion_name"


@@ -0,0 +1,11 @@
from dataclasses import asdict, dataclass
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
@dataclass
class SpeakerEncoderConfig(BaseEncoderConfig):
"""Defines parameters for Speaker Encoder model."""
model: str = "speaker_encoder"
class_name_key: str = "speaker_name"

TTS/encoder/dataset.py

@@ -0,0 +1,149 @@
import random
import torch
from torch.utils.data import Dataset
from TTS.encoder.utils.generic_utils import AugmentWAV
class EncoderDataset(Dataset):
def __init__(
self,
config,
ap,
meta_data,
voice_len=1.6,
num_classes_in_batch=64,
num_utter_per_class=10,
verbose=False,
augmentation_config=None,
use_torch_spec=None,
):
"""
Args:
ap (TTS.utils.audio.AudioProcessor): audio processor object.
meta_data (list): list of dataset instances.
voice_len (float): voice segment length in seconds.
verbose (bool): print diagnostic information.
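config (Coqpit): encoder config; class_name_key selects the label field of each item.
num_classes_in_batch (int): number of classes sampled in each batch.
num_utter_per_class (int): number of utterances per class in each batch.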
"""
super().__init__()
self.config = config
self.items = meta_data
self.sample_rate = ap.sample_rate
self.seq_len = int(voice_len * self.sample_rate)
self.num_utter_per_class = num_utter_per_class
self.ap = ap
self.verbose = verbose
self.use_torch_spec = use_torch_spec
self.classes, self.items = self.__parse_items()
self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}
# Data Augmentation
self.augmentator = None
self.gaussian_augmentation_config = None
if augmentation_config:
self.data_augmentation_p = augmentation_config["p"]
if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
self.augmentator = AugmentWAV(ap, augmentation_config)
if "gaussian" in augmentation_config.keys():
self.gaussian_augmentation_config = augmentation_config["gaussian"]
if self.verbose:
print("\n > DataLoader initialization")
print(f" | > Classes per Batch: {num_classes_in_batch}")
print(f" | > Number of instances : {len(self.items)}")
print(f" | > Sequence length: {self.seq_len}")
print(f" | > Num Classes: {len(self.classes)}")
print(f" | > Classes: {self.classes}")
def load_wav(self, filename):
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
return audio
def __parse_items(self):
class_to_utters = {}
for item in self.items:
path_ = item["audio_file"]
class_name = item[self.config.class_name_key]
if class_name in class_to_utters.keys():
class_to_utters[class_name].append(path_)
else:
class_to_utters[class_name] = [
path_,
]
# skip classes with fewer than self.num_utter_per_class samples
class_to_utters = {
k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class
}
classes = list(class_to_utters.keys())
classes.sort()
new_items = []
for item in self.items:
path_ = item["audio_file"]
class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
# ignore filtered classes
if class_name not in classes:
continue
# ignore small audios
if self.load_wav(path_).shape[0] - self.seq_len <= 0:
continue
new_items.append({"wav_file_path": path_, "class_name": class_name})
return classes, new_items
def __len__(self):
return len(self.items)
def get_num_classes(self):
return len(self.classes)
def get_class_list(self):
return self.classes
def set_classes(self, classes):
self.classes = classes
self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}
def get_map_classid_to_classname(self):
return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())
def __getitem__(self, idx):
return self.items[idx]
def collate_fn(self, batch):
# get the batch class_ids
labels = []
feats = []
for item in batch:
utter_path = item["wav_file_path"]
class_name = item["class_name"]
# get classid
class_id = self.classname_to_classid[class_name]
# load wav file
wav = self.load_wav(utter_path)
offset = random.randint(0, wav.shape[0] - self.seq_len)
wav = wav[offset : offset + self.seq_len]
if self.augmentator is not None and self.data_augmentation_p:
if random.random() < self.data_augmentation_p:
wav = self.augmentator.apply_one(wav)
if not self.use_torch_spec:
mel = self.ap.melspectrogram(wav)
feats.append(torch.FloatTensor(mel))
else:
feats.append(torch.FloatTensor(wav))
labels.append(class_id)
feats = torch.stack(feats)
labels = torch.LongTensor(labels)
return feats, labels


@@ -189,6 +189,11 @@ class SoftmaxLoss(nn.Module):
return L
def inference(self, embedding):
x = self.fc(embedding)
activations = torch.nn.functional.softmax(x, dim=1).squeeze(0)
class_id = torch.argmax(activations)
return class_id
class SoftmaxAngleProtoLoss(nn.Module):
"""


@@ -0,0 +1,145 @@
import torch
import torchaudio
import numpy as np
from torch import nn
from TTS.utils.io import load_fsspec
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.utils.generic_utils import set_init_dict
from coqpit import Coqpit
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
def forward(self, x):
assert len(x.size()) == 2
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
class BaseEncoder(nn.Module):
"""Base `encoder` class. Every new `encoder` model must inherit this.
It defines common `encoder` specific functions.
"""
# pylint: disable=W0102
def __init__(self):
super(BaseEncoder, self).__init__()
def get_torch_mel_spectrogram_class(self, audio_config):
return torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
)
)
@torch.no_grad()
def inference(self, x, l2_norm=True):
return self.forward(x, l2_norm)
@torch.no_grad()
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
"""
Generate embeddings for a batch of utterances
x: 1xTxD
"""
# map to the waveform size
if self.use_torch_spec:
num_frames = num_frames * self.audio_config["hop_length"]
max_len = x.shape[1]
if max_len < num_frames:
num_frames = max_len
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
frames_batch = []
for offset in offsets:
offset = int(offset)
end_offset = int(offset + num_frames)
frames = x[:, offset:end_offset]
frames_batch.append(frames)
frames_batch = torch.cat(frames_batch, dim=0)
embeddings = self.inference(frames_batch, l2_norm=l2_norm)
if return_mean:
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
return embeddings
def get_criterion(self, c: Coqpit, num_classes=None):
if c.loss == "ge2e":
criterion = GE2ELoss(loss_method="softmax")
elif c.loss == "angleproto":
criterion = AngleProtoLoss()
elif c.loss == "softmaxproto":
criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
else:
raise Exception("The %s not is a loss supported" % c.loss)
return criterion
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
try:
self.load_state_dict(state["model"])
except (KeyError, RuntimeError) as error:
# If eval raise the error
if eval:
raise error
print(" > Partial model initialization.")
model_dict = self.state_dict()
model_dict = set_init_dict(model_dict, state["model"], c)
self.load_state_dict(model_dict)
del model_dict
# load the criterion for restore_path
if criterion is not None and "criterion" in state:
try:
criterion.load_state_dict(state["criterion"])
except (KeyError, RuntimeError) as error:
print(" > Criterion load ignored because of:", error)
# instantiate and load the criterion for the encoder classifier at inference time
if eval and criterion is None and "criterion" in state and getattr(config, 'map_classid_to_classname', None) is not None:
criterion = self.get_criterion(config, len(config.map_classid_to_classname))
criterion.load_state_dict(state["criterion"])
if use_cuda:
self.cuda()
if criterion is not None:
criterion = criterion.cuda()
if eval:
self.eval()
assert not self.training
if not eval:
return criterion, state["step"]
return criterion
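As a usage illustration that is not part of this commit, here is a minimal sketch of the shared BaseEncoder API through the LSTM model; the audio settings, dummy waveform, and checkpoint path are placeholders.

import torch
from TTS.encoder.models.lstm import LSTMSpeakerEncoder

# placeholder audio settings; in practice take them from the encoder config's audio section
audio_config = {"preemphasis": 0.97, "sample_rate": 16000, "fft_size": 400,
                "win_length": 400, "hop_length": 160, "num_mels": 40}
model = LSTMSpeakerEncoder(input_dim=40, proj_dim=256, use_torch_spec=True, audio_config=audio_config)
# criterion = model.load_checkpoint(config, "best_model.pth.tar", eval=True)  # restore trained weights (placeholder path)

wav = torch.randn(1, 16000 * 3)     # 3 s of dummy audio at 16 kHz
emb = model.compute_embedding(wav)  # slices num_eval windows, embeds them, and averages
print(emb.shape)                    # torch.Size([1, 256])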


@@ -0,0 +1,99 @@
import torch
from torch import nn
from TTS.encoder.models.base_encoder import BaseEncoder
class LSTMWithProjection(nn.Module):
def __init__(self, input_size, hidden_size, proj_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.proj_size = proj_size
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
self.linear = nn.Linear(hidden_size, proj_size, bias=False)
def forward(self, x):
self.lstm.flatten_parameters()
o, (_, _) = self.lstm(x)
return self.linear(o)
class LSTMWithoutProjection(nn.Module):
def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
super().__init__()
self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)
self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
self.relu = nn.ReLU()
def forward(self, x):
_, (hidden, _) = self.lstm(x)
return self.relu(self.linear(hidden[-1]))
class LSTMSpeakerEncoder(BaseEncoder):
def __init__(
self,
input_dim,
proj_dim=256,
lstm_dim=768,
num_lstm_layers=3,
use_lstm_with_projection=True,
use_torch_spec=False,
audio_config=None,
):
super().__init__()
self.use_lstm_with_projection = use_lstm_with_projection
self.use_torch_spec = use_torch_spec
self.audio_config = audio_config
self.proj_dim = proj_dim
layers = []
# choose the LSTM layer type
if use_lstm_with_projection:
layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
for _ in range(num_lstm_layers - 1):
layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
self.layers = nn.Sequential(*layers)
else:
self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
else:
self.torch_spec = None
self._init_layers()
def _init_layers(self):
for name, param in self.layers.named_parameters():
if "bias" in name:
nn.init.constant_(param, 0.0)
elif "weight" in name:
nn.init.xavier_normal_(param)
def forward(self, x, l2_norm=True):
"""Forward pass of the model.
Args:
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
to compute the spectrogram on-the-fly.
l2_norm (bool): Whether to L2-normalize the outputs.
Shapes:
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
"""
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
if self.use_torch_spec:
x.squeeze_(1)
x = self.torch_spec(x)
x = self.instancenorm(x).transpose(1, 2)
d = self.layers(x)
if self.use_lstm_with_projection:
d = d[:, -1]
if l2_norm:
d = torch.nn.functional.normalize(d, p=2, dim=1)
return d


@@ -1,24 +1,8 @@ import numpy as np
import numpy as np
import torch
import torchaudio
from torch import nn
# from TTS.utils.audio import TorchSTFT
from TTS.utils.io import load_fsspec
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
def forward(self, x):
assert len(x.size()) == 2
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
from TTS.encoder.models.base_encoder import BaseEncoder
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
@@ -71,7 +55,7 @@ class SEBasicBlock(nn.Module):
return out
class ResNetSpeakerEncoder(nn.Module):
class ResNetSpeakerEncoder(BaseEncoder):
"""Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
Adapted from: https://github.com/clovaai/voxceleb_trainer
"""
@@ -110,32 +94,7 @@ class ResNetSpeakerEncoder(nn.Module):
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
else:
self.torch_spec = None
@@ -238,47 +197,3 @@ class ResNetSpeakerEncoder(nn.Module):
if l2_norm:
x = torch.nn.functional.normalize(x, p=2, dim=1)
return x
@torch.no_grad()
def inference(self, x, l2_norm=False):
return self.forward(x, l2_norm)
@torch.no_grad()
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
"""
Generate embeddings for a batch of utterances
x: 1xTxD
"""
# map to the waveform size
if self.use_torch_spec:
num_frames = num_frames * self.audio_config["hop_length"]
max_len = x.shape[1]
if max_len < num_frames:
num_frames = max_len
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
frames_batch = []
for offset in offsets:
offset = int(offset)
end_offset = int(offset + num_frames)
frames = x[:, offset:end_offset]
frames_batch.append(frames)
frames_batch = torch.cat(frames_batch, dim=0)
embeddings = self.inference(frames_batch, l2_norm=l2_norm)
if return_mean:
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
return embeddings
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if use_cuda:
self.cuda()
if eval:
self.eval()
assert not self.training


@@ -3,60 +3,15 @@ import glob
import os
import random
import re
from multiprocessing import Manager
import numpy as np
from scipy import signal
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec
class Storage(object):
def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
# use multiprocessing for threading safe
self.storage = Manager().list()
self.maxsize = maxsize
self.num_speakers_in_batch = num_speakers_in_batch
self.num_threads = num_threads
self.ignore_last_batch = False
if storage_batchs >= 3:
self.ignore_last_batch = True
# used for fast random sample
self.safe_storage_size = self.maxsize - self.num_threads
if self.ignore_last_batch:
self.safe_storage_size -= self.num_speakers_in_batch
def __len__(self):
return len(self.storage)
def full(self):
return len(self.storage) >= self.maxsize
def append(self, item):
# if storage is full, remove an item
if self.full():
self.storage.pop(0)
self.storage.append(item)
def get_random_sample(self):
# safe storage size considering all threads remove one item from storage in same time
storage_size = len(self.storage) - self.num_threads
if self.ignore_last_batch:
storage_size -= self.num_speakers_in_batch
return self.storage[random.randint(0, storage_size)]
def get_random_sample_fast(self):
"""Call this method only when storage is full"""
return self.storage[random.randint(0, self.safe_storage_size)]
class AugmentWAV(object):
def __init__(self, ap, augmentation_config):
@@ -209,7 +164,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
if model_loss < best_loss:
new_state_dict = model.state_dict()
state = {
@@ -217,6 +172,7 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
"optimizer": optimizer.state_dict(),
"criterion": criterion.state_dict(),
"step": current_step,
"epoch": epoch,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}


@@ -0,0 +1,102 @@
import random
from torch.utils.data.sampler import Sampler, SubsetRandomSampler
class SubsetSampler(Sampler):
"""
Samples elements sequentially from a given list of indices.
Args:
indices (list): a sequence of indices
"""
def __init__(self, indices):
super().__init__(indices)
self.indices = indices
def __iter__(self):
return (self.indices[i] for i in range(len(self.indices)))
def __len__(self):
return len(self.indices)
class PerfectBatchSampler(Sampler):
"""
Samples mini-batches of indices for balanced class batching.
Args:
dataset_items(list): dataset items to sample from.
classes (list): list of classes of dataset_items to sample from.
batch_size (int): total number of samples to be sampled in a mini-batch.
num_gpus (int): number of GPU in the data parallel mode.
shuffle (bool): if True, samples randomly, otherwise samples sequentially.
drop_last (bool): if True, drops last incomplete batch.
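num_classes_in_batch (int): number of distinct classes sampled in each mini-batch.
label_key (str): key of the item dict that holds the class label. Defaults to "class_name".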
"""
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"):
super().__init__(dataset_items)
assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
label_indices = {}
for idx, item in enumerate(dataset_items):
label = item[label_key]
if label not in label_indices.keys():
label_indices[label] = [idx]
else:
label_indices[label].append(idx)
if shuffle:
self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
else:
self._samplers = [SubsetSampler(label_indices[key]) for key in classes]
self._batch_size = batch_size
self._drop_last = drop_last
self._dp_devices = num_gpus
self._num_classes_in_batch = num_classes_in_batch
def __iter__(self):
batch = []
if self._num_classes_in_batch != len(self._samplers):
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
else:
valid_samplers_idx = None
iters = [iter(s) for s in self._samplers]
done = False
while True:
b = []
for i, it in enumerate(iters):
if valid_samplers_idx is not None and i not in valid_samplers_idx:
continue
idx = next(it, None)
if idx is None:
done = True
break
b.append(idx)
if done:
break
batch += b
if len(batch) == self._batch_size:
yield batch
batch = []
if valid_samplers_idx is not None:
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
if not self._drop_last:
if len(batch) > 0:
groups = len(batch) // self._num_classes_in_batch
if groups % self._dp_devices == 0:
yield batch
else:
batch = batch[:(groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
if len(batch) > 0:
yield batch
def __len__(self):
class_batch_size = self._batch_size // self._num_classes_in_batch
return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
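A toy sketch, not taken from the commit, of what the sampler yields: the items mimic EncoderDataset's "class_name" convention and the file names are made up.

from TTS.encoder.utils.samplers import PerfectBatchSampler

# four fake utterances for each of three classes
items = [{"class_name": c, "wav_file_path": f"{c}_{i}.wav"} for c in ("a", "b", "c") for i in range(4)]

sampler = PerfectBatchSampler(
    items,
    classes=["a", "b", "c"],
    batch_size=3 * 2,          # num_classes_in_batch * num_utter_per_class, as in setup_loader()
    num_classes_in_batch=3,
    shuffle=False,
    drop_last=True,
)
for batch in sampler:
    # each batch interleaves the classes: ['a', 'b', 'c', 'a', 'b', 'c']
    print([items[i]["class_name"] for i in batch])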


@@ -29,14 +29,18 @@ colormap = (
)
def plot_embeddings(embeddings, num_utter_per_speaker):
embeddings = embeddings[: 10 * num_utter_per_speaker]
def plot_embeddings(embeddings, num_classes_in_batch):
num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
# if necessary get just the first 10 classes
if num_classes_in_batch > 10:
num_classes_in_batch = 10
embeddings = embeddings[: num_classes_in_batch * num_utter_per_class]
model = umap.UMAP()
projection = model.fit_transform(embeddings)
num_speakers = embeddings.shape[0] // num_utter_per_speaker
ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker)
ground_truth = np.repeat(np.arange(num_classes_in_batch), num_utter_per_class)
colors = [colormap[i] for i in ground_truth]
fig, ax = plt.subplots(figsize=(16, 10))
_ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
plt.gca().set_aspect("equal", "datalim")


@@ -1,118 +0,0 @@
{
"model_name": "lstm",
"run_name": "mueller91",
"run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ",
"audio":{
// Audio processing parameters
"num_mels": 40, // size of the mel spec frame.
"fft_size": 400, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"win_length": 400, // stft window length in ms.
"hop_length": 160, // stft window hop-lengh in ms.
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
// Normalization parameters
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": true, // move normalization to range [-1, 1]
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
},
"reinit_layers": [],
"loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA)
"grad_clip": 3.0, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"steps_plot_stats": 10, // number of steps to plot embeddings.
"num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"num_utters_per_speaker": 10, //
"skip_speakers": false, // skip speakers with samples less than "num_utters_per_speaker"
"voice_len": 1.6, // number of seconds for each training instance
"num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
"print_step": 20, // Number of steps to log traning on console.
"output_path": "../../MozillaTTSOutput/checkpoints/voxceleb_librispeech/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
"model": {
"input_dim": 40,
"proj_dim": 256,
"lstm_dim": 768,
"num_lstm_layers": 3,
"use_lstm_with_projection": true
},
"audio_augmentation": {
"p": 0,
//add a gaussian noise to the data in order to increase robustness
"gaussian":{ // as the insertion of Gaussian noise is quick to be calculated, we added it after loading the wav file, this way, even audios that were reused with the cache can receive this noise
"p": 1, // propability of apply this method, 0 is disable
"min_amplitude": 0.0,
"max_amplitude": 1e-5
}
},
"storage": {
"sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
"storage_size": 15, // the size of the in-memory storage with respect to a single batch
"additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness
},
"datasets":
[
{
"name": "vctk_slim",
"path": "../../../audio-datasets/en/VCTK-Corpus/",
"meta_file_train": null,
"meta_file_val": null
},
{
"name": "libri_tts",
"path": "../../../audio-datasets/en/LibriTTS/train-clean-100",
"meta_file_train": null,
"meta_file_val": null
},
{
"name": "libri_tts",
"path": "../../../audio-datasets/en/LibriTTS/train-clean-360",
"meta_file_train": null,
"meta_file_val": null
},
{
"name": "libri_tts",
"path": "../../../audio-datasets/en/LibriTTS/train-other-500",
"meta_file_train": null,
"meta_file_val": null
},
{
"name": "voxceleb1",
"path": "../../../audio-datasets/en/voxceleb1/",
"meta_file_train": null,
"meta_file_val": null
},
{
"name": "voxceleb2",
"path": "../../../audio-datasets/en/voxceleb2/",
"meta_file_train": null,
"meta_file_val": null
},
{
"name": "common_voice",
"path": "../../../audio-datasets/en/MozillaCommonVoice",
"meta_file_train": "train.tsv",
"meta_file_val": "test.tsv"
}
]
}


@@ -1,956 +0,0 @@
{
"model": "speaker_encoder",
"run_name": "speaker_encoder",
"run_description": "resnet speaker encoder trained with commonvoice all languages dev and train, Voxceleb 1 dev and Voxceleb 2 dev",
// AUDIO PARAMETERS
"audio":{
// Audio processing parameters
"num_mels": 80, // size of the mel spec frame.
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"win_length": 1024, // stft window length in ms.
"hop_length": 256, // stft window hop-lengh in ms.
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
"stft_pad_mode": "reflect",
// Normalization parameters
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": true, // move normalization to range [-1, 1]
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"spec_gain": 20.0,
"do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
},
"reinit_layers": [],
"loss": "angleproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss
"grad_clip": 3.0, // upper limit for gradients for clipping.
"max_train_step": 1000000, // total number of steps to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"steps_plot_stats": 100, // number of steps to plot embeddings.
// Speakers config
"num_speakers_in_batch": 200, // Batch size for training.
"num_utters_per_speaker": 2, //
"skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
"voice_len": 2, // number of seconds for each training instance
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 1000, // Number of training steps expected to save the best checkpoints in training.
"print_step": 50, // Number of steps to log traning on console.
"output_path": "../checkpoints/speaker_encoder/angleproto/resnet_voxceleb1_and_voxceleb2-and-common-voice-all-using-angleproto/", // DATASET-RELATED: output path for all training outputs.
"audio_augmentation": {
"p": 0.5, // propability of apply this method, 0 is disable rir and additive noise augmentation
"rir":{
"rir_path": "/workspace/store/ecasanova/ComParE/RIRS_NOISES/simulated_rirs/",
"conv_mode": "full"
},
"additive":{
"sounds_path": "/workspace/store/ecasanova/ComParE/musan/",
// list of each of the directories in your data augmentation, if a directory is in "sounds_path" but is not listed here it will be ignored
"speech":{
"min_snr_in_db": 13,
"max_snr_in_db": 20,
"min_num_noises": 2,
"max_num_noises": 3
},
"noise":{
"min_snr_in_db": 0,
"max_snr_in_db": 15,
"min_num_noises": 1,
"max_num_noises": 1
},
"music":{
"min_snr_in_db": 5,
"max_snr_in_db": 15,
"min_num_noises": 1,
"max_num_noises": 1
}
},
//add a gaussian noise to the data in order to increase robustness
"gaussian":{ // as the insertion of Gaussian noise is quick to be calculated, we added it after loading the wav file, this way, even audios that were reused with the cache can receive this noise
"p": 0.5, // propability of apply this method, 0 is disable
"min_amplitude": 0.0,
"max_amplitude": 1e-5
}
},
"model_params": {
"model_name": "resnet",
"input_dim": 80,
"proj_dim": 512
},
"storage": {
"sample_from_storage_p": 0.5, // the probability with which we'll sample from the DataSet in-memory storage
"storage_size": 35 // the size of the in-memory storage with respect to a single batch
},
"datasets":
[
{
"name": "voxceleb2",
"path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox2_dev_aac/",
"meta_file_train": null,
"meta_file_val": null
},
{
"name": "voxceleb1",
"path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox1_dev_wav/",
"meta_file_train": null,
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv",
"meta_file_train": "dev.tsv",
"meta_file_val": null
}
]
}

View File

@ -1,957 +0,0 @@
{
"model": "speaker_encoder",
"run_name": "speaker_encoder",
"run_description": "resnet speaker encoder trained with commonvoice all languages dev and train, Voxceleb 1 dev and Voxceleb 2 dev",
// AUDIO PARAMETERS
"audio":{
// Audio processing parameters
"num_mels": 80, // size of the mel spec frame.
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"win_length": 1024, // stft window length in ms.
"hop_length": 256, // stft window hop-lengh in ms.
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
"stft_pad_mode": "reflect",
// Normalization parameters
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": true, // move normalization to range [-1, 1]
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"spec_gain": 20.0,
"do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
},
"reinit_layers": [],
"loss": "softmaxproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss
"grad_clip": 3.0, // upper limit for gradients for clipping.
"max_train_step": 1000000, // total number of steps to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"steps_plot_stats": 100, // number of steps to plot embeddings.
// Speakers config
"num_speakers_in_batch": 200, // Batch size for training.
"num_utters_per_speaker": 2, //
"skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
"voice_len": 2, // number of seconds for each training instance
"num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 1000, // Number of training steps expected to save the best checkpoints in training.
"print_step": 50, // Number of steps to log traning on console.
"output_path": "../../../checkpoints/speaker_encoder/resnet_voxceleb1_and_voxceleb2-and-common-voice-all/", // DATASET-RELATED: output path for all training outputs.
"audio_augmentation": {
"p": 0.5, // propability of apply this method, 0 is disable rir and additive noise augmentation
"rir":{
"rir_path": "/workspace/store/ecasanova/ComParE/RIRS_NOISES/simulated_rirs/",
"conv_mode": "full"
},
"additive":{
"sounds_path": "/workspace/store/ecasanova/ComParE/musan/",
// list the directories used for data augmentation; a directory inside "sounds_path" that is not listed here is ignored
"speech":{
"min_snr_in_db": 13,
"max_snr_in_db": 20,
"min_num_noises": 2,
"max_num_noises": 3
},
"noise":{
"min_snr_in_db": 0,
"max_snr_in_db": 15,
"min_num_noises": 1,
"max_num_noises": 1
},
"music":{
"min_snr_in_db": 5,
"max_snr_in_db": 15,
"min_num_noises": 1,
"max_num_noises": 1
}
},
// add Gaussian noise to the data to increase robustness
"gaussian":{ // Gaussian noise is cheap to compute, so it is applied after loading the wav file; this way even audios reused from the cache receive this noise
"p": 0.5, // probability of applying this method; 0 disables it
"min_amplitude": 0.0,
"max_amplitude": 1e-5
}
},
"model_params": {
"model_name": "resnet",
"input_dim": 80,
"proj_dim": 512
},
"storage": {
"sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
"storage_size": 35 // the size of the in-memory storage with respect to a single batch
},
"datasets":
[
{
"name": "voxceleb2",
"path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox2_dev_aac/",
"meta_file_train": null,
"meta_file_val": null
},
{
"name": "voxceleb1",
"path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox1_dev_wav/",
"meta_file_train": null,
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb",
"meta_file_train": "dev.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv",
"meta_file_train": "train.tsv",
"meta_file_val": null
},
{
"name": "common_voice",
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv",
"meta_file_train": "dev.tsv",
"meta_file_val": null
}
]
}
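For orientation, the batch and storage sizes implied by this removed config, as a back-of-the-envelope sketch (not code from the repo; the formulas mirror the removed SpeakerEncoderDataset shown further below):

# Rough arithmetic implied by the config values above (hypothetical helper, not repo code).
sample_rate = 16000
voice_len = 2                      # seconds per training segment
num_speakers_in_batch = 200
num_utters_per_speaker = 2
storage_size = 35                  # batches' worth of cached samples

seq_len = int(voice_len * sample_rate)                                 # 32000 samples per segment
utterances_per_step = num_speakers_in_batch * num_utters_per_speaker   # 400 utterances per batch
storage_entries = storage_size * num_speakers_in_batch                 # 7000 cached (wavs, labels) groups
print(seq_len, utterances_per_step, storage_entries)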

View File

@ -1,243 +0,0 @@
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage
class SpeakerEncoderDataset(Dataset):
def __init__(
self,
ap,
meta_data,
voice_len=1.6,
num_speakers_in_batch=64,
storage_size=1,
sample_from_storage_p=0.5,
num_utter_per_speaker=10,
skip_speakers=False,
verbose=False,
augmentation_config=None,
use_torch_spec=None,
):
"""
Args:
ap (TTS.tts.utils.AudioProcessor): audio processor object.
meta_data (list): list of dataset instances.
voice_len (float): voice segment length in seconds.
verbose (bool): print diagnostic information.
"""
super().__init__()
self.items = meta_data
self.sample_rate = ap.sample_rate
self.seq_len = int(voice_len * self.sample_rate)
self.num_speakers_in_batch = num_speakers_in_batch
self.num_utter_per_speaker = num_utter_per_speaker
self.skip_speakers = skip_speakers
self.ap = ap
self.verbose = verbose
self.use_torch_spec = use_torch_spec
self.__parse_items()
storage_max_size = storage_size * num_speakers_in_batch
self.storage = Storage(
maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch
)
self.sample_from_storage_p = float(sample_from_storage_p)
speakers_aux = list(self.speakers)
speakers_aux.sort()
self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
# Augmentation
self.augmentator = None
self.gaussian_augmentation_config = None
if augmentation_config:
self.data_augmentation_p = augmentation_config["p"]
if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
self.augmentator = AugmentWAV(ap, augmentation_config)
if "gaussian" in augmentation_config.keys():
self.gaussian_augmentation_config = augmentation_config["gaussian"]
if self.verbose:
print("\n > DataLoader initialization")
print(f" | > Speakers per Batch: {num_speakers_in_batch}")
print(f" | > Storage Size: {storage_max_size} instances, each with {num_utter_per_speaker} utters")
print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
print(f" | > Number of instances : {len(self.items)}")
print(f" | > Sequence length: {self.seq_len}")
print(f" | > Num speakers: {len(self.speakers)}")
def load_wav(self, filename):
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
return audio
def __parse_items(self):
self.speaker_to_utters = {}
for i in self.items:
path_ = i["audio_file"]
speaker_ = i["speaker_name"]
if speaker_ in self.speaker_to_utters.keys():
self.speaker_to_utters[speaker_].append(path_)
else:
self.speaker_to_utters[speaker_] = [
path_,
]
if self.skip_speakers:
self.speaker_to_utters = {
k: v for (k, v) in self.speaker_to_utters.items() if len(v) >= self.num_utter_per_speaker
}
self.speakers = [k for (k, v) in self.speaker_to_utters.items()]
def __len__(self):
return int(1e10)
def get_num_speakers(self):
return len(self.speakers)
def __sample_speaker(self, ignore_speakers=None):
speaker = random.sample(self.speakers, 1)[0]
# if a list of speaker ids is provided, make sure they are ignored
if ignore_speakers and self.speakerid_to_classid[speaker] in ignore_speakers:
while True:
speaker = random.sample(self.speakers, 1)[0]
if self.speakerid_to_classid[speaker] not in ignore_speakers:
break
if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
else:
utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker)
return speaker, utters
def __sample_speaker_utterances(self, speaker):
"""
Sample all M utterances for the given speaker.
"""
wavs = []
labels = []
for _ in range(self.num_utter_per_speaker):
# TODO:dummy but works
while True:
# remove speakers that have fewer than 2 utterances
if len(self.speaker_to_utters[speaker]) > 1:
utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
else:
if speaker in self.speakers:
self.speakers.remove(speaker)
speaker, _ = self.__sample_speaker()
continue
wav = self.load_wav(utter)
if wav.shape[0] - self.seq_len > 0:
break
if utter in self.speaker_to_utters[speaker]:
self.speaker_to_utters[speaker].remove(utter)
if self.augmentator is not None and self.data_augmentation_p:
if random.random() < self.data_augmentation_p:
wav = self.augmentator.apply_one(wav)
wavs.append(wav)
labels.append(self.speakerid_to_classid[speaker])
return wavs, labels
def __getitem__(self, idx):
speaker, _ = self.__sample_speaker()
speaker_id = self.speakerid_to_classid[speaker]
return speaker, speaker_id
def __load_from_disk_and_storage(self, speaker):
# don't sample from storage, but from HDD
wavs_, labels_ = self.__sample_speaker_utterances(speaker)
# put the newly loaded item into storage
self.storage.append((wavs_, labels_))
return wavs_, labels_
def collate_fn(self, batch):
# get the batch speaker_ids
batch = np.array(batch)
speakers_id_in_batch = set(batch[:, 1].astype(np.int32))
labels = []
feats = []
speakers = set()
for speaker, speaker_id in batch:
speaker_id = int(speaker_id)
# ensure that a speaker appears only once in the batch
if speaker_id in speakers:
# remove current speaker
if speaker_id in speakers_id_in_batch:
speakers_id_in_batch.remove(speaker_id)
speaker, _ = self.__sample_speaker(ignore_speakers=speakers_id_in_batch)
speaker_id = self.speakerid_to_classid[speaker]
speakers_id_in_batch.add(speaker_id)
if random.random() < self.sample_from_storage_p and self.storage.full():
# sample from storage (if full)
wavs_, labels_ = self.storage.get_random_sample_fast()
# force choosing the current speaker, or another speaker not already in the batch
# This is necessary for proper training with the AngleProto and GE2E losses
if labels_[0] in speakers_id_in_batch and labels_[0] != speaker_id:
attempts = 0
while True:
wavs_, labels_ = self.storage.get_random_sample_fast()
if labels_[0] == speaker_id or labels_[0] not in speakers_id_in_batch:
break
attempts += 1
# Try 5 times; after that, load from disk
if attempts >= 5:
wavs_, labels_ = self.__load_from_disk_and_storage(speaker)
break
else:
# don't sample from storage, but from HDD
wavs_, labels_ = self.__load_from_disk_and_storage(speaker)
# append speaker for control
speakers.add(labels_[0])
# remove current speaker and append other
if speaker_id in speakers_id_in_batch:
speakers_id_in_batch.remove(speaker_id)
speakers_id_in_batch.add(labels_[0])
# get a random subset of each of the wavs and extract mel spectrograms.
feats_ = []
for wav in wavs_:
offset = random.randint(0, wav.shape[0] - self.seq_len)
wav = wav[offset : offset + self.seq_len]
# add random gaussian noise
if self.gaussian_augmentation_config and self.gaussian_augmentation_config["p"]:
if random.random() < self.gaussian_augmentation_config["p"]:
wav += np.random.normal(
self.gaussian_augmentation_config["min_amplitude"],
self.gaussian_augmentation_config["max_amplitude"],
size=len(wav),
)
if not self.use_torch_spec:
mel = self.ap.melspectrogram(wav)
feats_.append(torch.FloatTensor(mel))
else:
feats_.append(torch.FloatTensor(wav))
labels.append(torch.LongTensor(labels_))
feats.extend(feats_)
feats = torch.stack(feats)
labels = torch.stack(labels)
return feats, labels
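For context, this removed dataset was meant to be driven by a plain PyTorch DataLoader: each __getitem__ only picks a speaker, and collate_fn does the actual utterance sampling, storage lookups, and feature extraction. A minimal wiring sketch follows; `ap` and `meta_data` are assumed to exist already (an AudioProcessor and a formatter output with "audio_file"/"speaker_name" keys), and the real training-script wiring may differ:

from torch.utils.data import DataLoader

# `ap` (TTS.utils.audio.AudioProcessor) and `meta_data` (list of
# {"audio_file": ..., "speaker_name": ...} dicts) are assumed to be built elsewhere.
dataset = SpeakerEncoderDataset(
    ap,
    meta_data,
    voice_len=2,
    num_speakers_in_batch=64,
    num_utter_per_speaker=2,
    skip_speakers=True,
)
loader = DataLoader(
    dataset,
    batch_size=dataset.num_speakers_in_batch,  # one __getitem__ draw per speaker slot
    collate_fn=dataset.collate_fn,
)
feats, labels = next(iter(loader))
# feats: (num_speakers_in_batch * num_utter_per_speaker, num_mels, frames) mel batch
# labels: (num_speakers_in_batch, num_utter_per_speaker) speaker class ids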

View File

@ -1,189 +0,0 @@
import numpy as np
import torch
import torchaudio
from torch import nn
from TTS.speaker_encoder.models.resnet import PreEmphasis
from TTS.utils.io import load_fsspec
class LSTMWithProjection(nn.Module):
def __init__(self, input_size, hidden_size, proj_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.proj_size = proj_size
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
self.linear = nn.Linear(hidden_size, proj_size, bias=False)
def forward(self, x):
self.lstm.flatten_parameters()
o, (_, _) = self.lstm(x)
return self.linear(o)
class LSTMWithoutProjection(nn.Module):
def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
super().__init__()
self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)
self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
self.relu = nn.ReLU()
def forward(self, x):
_, (hidden, _) = self.lstm(x)
return self.relu(self.linear(hidden[-1]))
class LSTMSpeakerEncoder(nn.Module):
def __init__(
self,
input_dim,
proj_dim=256,
lstm_dim=768,
num_lstm_layers=3,
use_lstm_with_projection=True,
use_torch_spec=False,
audio_config=None,
):
super().__init__()
self.use_lstm_with_projection = use_lstm_with_projection
self.use_torch_spec = use_torch_spec
self.audio_config = audio_config
self.proj_dim = proj_dim
layers = []
# choose the LSTM layer type
if use_lstm_with_projection:
layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
for _ in range(num_lstm_layers - 1):
layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
self.layers = nn.Sequential(*layers)
else:
self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
# TorchSTFT(
# n_fft=audio_config["fft_size"],
# hop_length=audio_config["hop_length"],
# win_length=audio_config["win_length"],
# sample_rate=audio_config["sample_rate"],
# window="hamming_window",
# mel_fmin=0.0,
# mel_fmax=None,
# use_htk=True,
# do_amp_to_db=False,
# n_mels=audio_config["num_mels"],
# power=2.0,
# use_mel=True,
# mel_norm=None,
# )
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
else:
self.torch_spec = None
self._init_layers()
def _init_layers(self):
for name, param in self.layers.named_parameters():
if "bias" in name:
nn.init.constant_(param, 0.0)
elif "weight" in name:
nn.init.xavier_normal_(param)
def forward(self, x, l2_norm=True):
"""Forward pass of the model.
Args:
x (Tensor): Raw waveform signal or spectrogram frames. If the input is a waveform, `use_torch_spec` must be `True`
to compute the spectrogram on-the-fly.
l2_norm (bool): Whether to L2-normalize the outputs.
Shapes:
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
"""
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False):
if self.use_torch_spec:
x.squeeze_(1)
x = self.torch_spec(x)
x = self.instancenorm(x).transpose(1, 2)
d = self.layers(x)
if self.use_lstm_with_projection:
d = d[:, -1]
if l2_norm:
d = torch.nn.functional.normalize(d, p=2, dim=1)
return d
@torch.no_grad()
def inference(self, x, l2_norm=True):
d = self.forward(x, l2_norm=l2_norm)
return d
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
"""
Generate embeddings for a batch of utterances
x: 1xTxD
"""
max_len = x.shape[1]
if max_len < num_frames:
num_frames = max_len
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
frames_batch = []
for offset in offsets:
offset = int(offset)
end_offset = int(offset + num_frames)
frames = x[:, offset:end_offset]
frames_batch.append(frames)
frames_batch = torch.cat(frames_batch, dim=0)
embeddings = self.inference(frames_batch)
if return_mean:
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
return embeddings
def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
"""
Generate embeddings for a batch of utterances
x: BxTxD
"""
num_overlap = int(num_frames * overlap)  # must be an int so it can be used as a range step
max_len = x.shape[1]
embed = None
num_iters = seq_lens / (num_frames - num_overlap)
cur_iter = 0
for offset in range(0, max_len, num_frames - num_overlap):
cur_iter += 1
end_offset = min(x.shape[1], offset + num_frames)
frames = x[:, offset:end_offset]
if embed is None:
embed = self.inference(frames)
else:
embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
return embed / num_iters
# pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if use_cuda:
self.cuda()
if eval:
self.eval()
assert not self.training
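Below is a minimal, hypothetical usage sketch for the LSTM encoder listed above. The constructor keyword names mirror the attributes referenced in `__init__` (`input_dim`, `proj_dim`, `lstm_dim`, `num_lstm_layers`) and may not match the real defaults; the example feeds precomputed spectrogram frames, i.e. it assumes `use_torch_spec=False`.

import torch

from TTS.encoder.models.lstm import LSTMSpeakerEncoder

# Hypothetical instantiation; keyword names follow the attributes used above.
encoder = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
encoder.eval()

# Precomputed mel frames: shape (N, D_spec, T_in), as documented in forward().
spec = torch.rand(4, 80, 120)
d_vectors = encoder.inference(spec, l2_norm=True)  # -> (4, 256); rows are unit length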

Binary image file not shown (24 KiB)

View File

@ -264,7 +264,7 @@ class BaseTTSConfig(BaseTrainingConfig):
# dataset
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
# optimizer
optimizer: str = None
optimizer: str = "radam"
optimizer_params: dict = None
# scheduler
lr_scheduler: str = ""
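A hedged sketch of how these optimizer fields might be set on a config built from `BaseTTSConfig` (the import path is assumed, and the `optimizer_params` keys/values are illustrative; valid keys are whatever the chosen optimizer accepts):

from TTS.tts.configs.shared_configs import BaseTTSConfig  # assumed location

config = BaseTTSConfig(
    optimizer="radam",  # new default shown in the hunk above
    optimizer_params={"betas": [0.9, 0.998], "weight_decay": 1e-6},  # illustrative values
    lr_scheduler="",  # scheduler left disabled
)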

View File

@ -441,6 +441,26 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
return [x.strip().split("|") for x in f.readlines()]
def emotion(root_path, meta_file, ignored_speakers=None):
"""Generic emotion dataset"""
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
if line.startswith("file_path"):
continue
cols = line.split(",")
wav_file = os.path.join(root_path, cols[0])
speaker_id = cols[1]
emotion_id = cols[2].replace("\n", "")
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
return items
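For reference, a hypothetical meta file and call for the `emotion` formatter above; only the `file_path` header prefix and the three comma-separated columns come from the code, while the column titles and paths are illustrative:

# metadata.csv (illustrative):
#   file_path,speaker_id,emotion_id
#   wavs/0001.wav,spk_1,happy
#   wavs/0002.wav,spk_2,sad
items = emotion("/data/my_emotion_dataset", "metadata.csv", ignored_speakers=["spk_2"])
# -> [{"audio_file": "/data/my_emotion_dataset/wavs/0001.wav",
#      "speaker_name": "spk_1", "emotion_name": "happy"}]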
def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylint: disable=unused-argument
"""Normalizes the Baker meta data file to TTS format

View File

@ -9,7 +9,7 @@ import torch
from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default, load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.utils.audio import AudioProcessor
@ -269,7 +269,7 @@ class SpeakerManager:
"""
self.speaker_encoder_config = load_config(config_path)
self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda)
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:

View File

@ -3,9 +3,9 @@ import unittest
import torch as T
from tests import get_tests_input_path
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
file_path = get_tests_input_path()

View File

@ -4,14 +4,14 @@ import shutil
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig
from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig
from TTS.encoder.configs.speaker_encoder_config import SpeakerEncoderConfig
def run_test_train():
command = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.name ljspeech "
"--coqpit.datasets.0.name ljspeech_test "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
@ -24,17 +24,21 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = SpeakerEncoderConfig(
batch_size=4,
num_speakers_in_batch=1,
num_utters_per_speaker=10,
num_loader_workers=0,
max_train_step=2,
num_classes_in_batch=4,
num_utter_per_class=2,
eval_num_classes_in_batch=4,
eval_num_utter_per_class=2,
num_loader_workers=1,
epochs=1,
print_step=1,
save_step=1,
save_step=2,
print_eval=True,
run_eval=True,
audio=BaseAudioConfig(num_mels=80),
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.loss = "ge2e"
config.save_json(config_path)
print(config)
@ -69,14 +73,14 @@ run_cli(command_train)
shutil.rmtree(continue_path)
# test model with ge2e loss function
config.loss = "ge2e"
config.save_json(config_path)
run_test_train()
# config.loss = "ge2e"
# config.save_json(config_path)
# run_test_train()
# test model with angleproto loss function
config.loss = "angleproto"
config.save_json(config_path)
run_test_train()
# config.loss = "angleproto"
# config.save_json(config_path)
# run_test_train()
# test model with softmaxproto loss function
config.loss = "softmaxproto"

View File

@ -6,8 +6,8 @@ import torch
from tests import get_tests_input_path
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.speaker_encoder.utils.io import save_checkpoint
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor

View File

@ -8,6 +8,7 @@ from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
from TTS.encoder.utils.samplers import PerfectBatchSampler
# Fixing random state to avoid random fails
torch.manual_seed(0)
@ -82,3 +83,51 @@ class TestSamplers(unittest.TestCase):
spk2 += 1
assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
def test_perfect_sampler(self): # pylint: disable=no-self-use
classes = set()
for item in train_samples:
classes.add(item["speaker_name"])
sampler = PerfectBatchSampler(
train_samples,
classes,
batch_size=2 * 3, # total batch size
num_classes_in_batch=2,
label_key="speaker_name",
shuffle=False,
drop_last=True)
batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
for batch in batchs:
spk1, spk2 = 0, 0
            # count how many items of each speaker are in the batch
for index in batch:
if train_samples[index]["speaker_name"] == "ljspeech-0":
spk1 += 1
else:
spk2 += 1
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
def test_perfect_sampler_shuffle(self): # pylint: disable=no-self-use
classes = set()
for item in train_samples:
classes.add(item["speaker_name"])
sampler = PerfectBatchSampler(
train_samples,
classes,
batch_size=2 * 3, # total batch size
num_classes_in_batch=2,
label_key="speaker_name",
shuffle=True,
drop_last=False)
batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
for batch in batchs:
spk1, spk2 = 0, 0
            # count how many items of each speaker are in the batch
for index in batch:
if train_samples[index]["speaker_name"] == "ljspeech-0":
spk1 += 1
else:
spk2 += 1
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
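As a usage sketch, the sampler exercised above yields lists of dataset indices, so it can be passed to a PyTorch `DataLoader` as `batch_sampler`; `my_dataset` and `my_collate_fn` below are placeholders and must index items consistently with `train_samples`:

from torch.utils.data import DataLoader

sampler = PerfectBatchSampler(
    train_samples,
    classes,  # set of class labels, built as in the tests above
    batch_size=2 * 3,  # total items per batch
    num_classes_in_batch=2,
    label_key="speaker_name",
    shuffle=True,
    drop_last=False,
)
loader = DataLoader(my_dataset, batch_sampler=sampler, collate_fn=my_collate_fn)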

View File

@ -66,8 +66,8 @@
"use_mas": false, // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments.
// TRAINING
"batch_size": 2, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size":1,
"batch_size": 8, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size": 8,
"r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"data_dep_init_iter": 1,

View File

@ -36,8 +36,8 @@
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"steps_plot_stats": 10, // number of steps to plot embeddings.
"num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"num_utters_per_speaker": 10, //
"num_classes_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"num_utter_per_class": 10, //
"num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"

View File

@ -61,8 +61,8 @@
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
// TRAINING
"batch_size": 1, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size":1,
"batch_size": 8, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size": 8,
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
"gradual_training": [[0, 7, 4], [1, 5, 2]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed.
"loss_masking": true, // enable / disable loss masking against the sequence padding.

View File

@ -61,8 +61,8 @@
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
// TRAINING
"batch_size": 1, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size":1,
"batch_size": 8, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size": 8,
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
"gradual_training": [[0, 7, 4], [1, 5, 2]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed.
"loss_masking": true, // enable / disable loss masking against the sequence padding.

View File

@ -7,7 +7,7 @@ from trainer.logging.tensorboard_logger import TensorboardLogger
from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
from TTS.tts.utils.speakers import SpeakerManager