mirror of https://github.com/coqui-ai/TTS.git
REBASED: Transform Speaker Encoder in a Generic Encoder and Implement Emotion Encoder training support (#1349)
* Rename Speaker encoder module to encoder * Add a generic emotion dataset formatter * Transform the Speaker Encoder dataset to a generic dataset and create emotion encoder config * Add class map in emotion config * Add Base encoder config * Add evaluation encoder script * Fix the bug in plot_embeddings * Enable Weight decay for encoder training * Add argumnet to disable storage * Add Perfect Sampler and remove storage * Add evaluation during encoder training * Fix lint checks * Remove useless config parameter * Active evaluation in speaker encoder test and use multispeaker dataset for this test * Unit tests fixs * Remove useless tests for speedup the aux_tests * Use get_optimizer in Encoder * Add BaseEncoder Class * Fix the unitests * Add Perfect Batch Sampler unit test * Add compute encoder accuracy in a functionpull/1395/head
parent
36e9ea2f97
commit
f81892483d
|
@ -42,33 +42,35 @@ c_dataset = load_config(args.config_dataset_path)
|
|||
meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
|
||||
wav_files = meta_data_train + meta_data_eval
|
||||
|
||||
speaker_manager = SpeakerManager(
|
||||
encoder_manager = SpeakerManager(
|
||||
encoder_model_path=args.model_path,
|
||||
encoder_config_path=args.config_path,
|
||||
d_vectors_file_path=args.old_file,
|
||||
use_cuda=args.use_cuda,
|
||||
)
|
||||
|
||||
class_name_key = encoder_manager.speaker_encoder_config.class_name_key
|
||||
|
||||
# compute speaker embeddings
|
||||
speaker_mapping = {}
|
||||
for idx, wav_file in enumerate(tqdm(wav_files)):
|
||||
if isinstance(wav_file, list):
|
||||
speaker_name = wav_file[2]
|
||||
wav_file = wav_file[1]
|
||||
if isinstance(wav_file, dict):
|
||||
class_name = wav_file[class_name_key]
|
||||
wav_file = wav_file["audio_file"]
|
||||
else:
|
||||
speaker_name = None
|
||||
class_name = None
|
||||
|
||||
wav_file_name = os.path.basename(wav_file)
|
||||
if args.old_file is not None and wav_file_name in speaker_manager.clip_ids:
|
||||
if args.old_file is not None and wav_file_name in encoder_manager.clip_ids:
|
||||
# get the embedding from the old file
|
||||
embedd = speaker_manager.get_d_vector_by_clip(wav_file_name)
|
||||
embedd = encoder_manager.get_d_vector_by_clip(wav_file_name)
|
||||
else:
|
||||
# extract the embedding
|
||||
embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
|
||||
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
|
||||
|
||||
# create speaker_mapping if target dataset is defined
|
||||
speaker_mapping[wav_file_name] = {}
|
||||
speaker_mapping[wav_file_name]["name"] = speaker_name
|
||||
speaker_mapping[wav_file_name]["name"] = class_name
|
||||
speaker_mapping[wav_file_name]["embedding"] = embedd
|
||||
|
||||
if speaker_mapping:
|
||||
|
@ -81,5 +83,5 @@ if speaker_mapping:
|
|||
os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)
|
||||
|
||||
# pylint: disable=W0212
|
||||
speaker_manager._save_json(mapping_file_path, speaker_mapping)
|
||||
encoder_manager._save_json(mapping_file_path, speaker_mapping)
|
||||
print("Speaker embeddings saved at:", mapping_file_path)
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
import argparse
|
||||
import torch
|
||||
from argparse import RawTextHelpFormatter
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
|
||||
def compute_encoder_accuracy(dataset_items, encoder_manager):
|
||||
|
||||
class_name_key = encoder_manager.speaker_encoder_config.class_name_key
|
||||
map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, 'map_classid_to_classname', None)
|
||||
|
||||
class_acc_dict = {}
|
||||
|
||||
# compute embeddings for all wav_files
|
||||
for item in tqdm(dataset_items):
|
||||
class_name = item[class_name_key]
|
||||
wav_file = item["audio_file"]
|
||||
|
||||
# extract the embedding
|
||||
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
|
||||
if encoder_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
|
||||
embedding = torch.FloatTensor(embedd).unsqueeze(0)
|
||||
if encoder_manager.use_cuda:
|
||||
embedding = embedding.cuda()
|
||||
|
||||
class_id = encoder_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
|
||||
predicted_label = map_classid_to_classname[str(class_id)]
|
||||
else:
|
||||
predicted_label = None
|
||||
|
||||
if class_name is not None and predicted_label is not None:
|
||||
is_equal = int(class_name == predicted_label)
|
||||
if class_name not in class_acc_dict:
|
||||
class_acc_dict[class_name] = [is_equal]
|
||||
else:
|
||||
class_acc_dict[class_name].append(is_equal)
|
||||
else:
|
||||
raise RuntimeError("Error: class_name or/and predicted_label are None")
|
||||
|
||||
acc_avg = 0
|
||||
for key, values in class_acc_dict.items():
|
||||
acc = sum(values)/len(values)
|
||||
print("Class", key, "Accuracy:", acc)
|
||||
acc_avg += acc
|
||||
|
||||
print("Average Accuracy:", acc_avg/len(class_acc_dict))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""Compute the accuracy of the encoder.\n\n"""
|
||||
"""
|
||||
Example runs:
|
||||
python TTS/bin/eval_encoder.py emotion_encoder_model.pth.tar emotion_encoder_config.json dataset_config.json
|
||||
""",
|
||||
formatter_class=RawTextHelpFormatter,
|
||||
)
|
||||
parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
|
||||
parser.add_argument(
|
||||
"config_path",
|
||||
type=str,
|
||||
help="Path to model config file.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"config_dataset_path",
|
||||
type=str,
|
||||
help="Path to dataset config file.",
|
||||
)
|
||||
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
|
||||
parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
c_dataset = load_config(args.config_dataset_path)
|
||||
|
||||
meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
|
||||
items = meta_data_train + meta_data_eval
|
||||
|
||||
enc_manager = SpeakerManager(
|
||||
encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
|
||||
)
|
||||
|
||||
compute_encoder_accuracy(items, enc_manager)
|
|
@ -10,16 +10,16 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
from trainer.torch import NoamLR
|
||||
|
||||
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
|
||||
from TTS.speaker_encoder.utils.training import init_training
|
||||
from TTS.speaker_encoder.utils.visual import plot_embeddings
|
||||
from TTS.encoder.dataset import EncoderDataset
|
||||
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
|
||||
from TTS.encoder.utils.samplers import PerfectBatchSampler
|
||||
from TTS.encoder.utils.training import init_training
|
||||
from TTS.encoder.utils.visual import plot_embeddings
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
|
||||
from TTS.utils.io import copy_model_files
|
||||
from trainer.trainer_utils import get_optimizer
|
||||
from TTS.utils.training import check_update
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
@ -32,164 +32,238 @@ print(" > Number of GPUs: ", num_gpus)
|
|||
|
||||
|
||||
def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
|
||||
num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
|
||||
num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
|
||||
|
||||
dataset = EncoderDataset(
|
||||
c,
|
||||
ap,
|
||||
meta_data_eval if is_val else meta_data_train,
|
||||
voice_len=c.voice_len,
|
||||
num_utter_per_class=num_utter_per_class,
|
||||
num_classes_in_batch=num_classes_in_batch,
|
||||
verbose=verbose,
|
||||
augmentation_config=c.audio_augmentation if not is_val else None,
|
||||
use_torch_spec=c.model_params.get("use_torch_spec", False),
|
||||
)
|
||||
# get classes list
|
||||
classes = dataset.get_class_list()
|
||||
|
||||
sampler = PerfectBatchSampler(
|
||||
dataset.items,
|
||||
classes,
|
||||
batch_size=num_classes_in_batch*num_utter_per_class, # total batch size
|
||||
num_classes_in_batch=num_classes_in_batch,
|
||||
num_gpus=1,
|
||||
shuffle=not is_val,
|
||||
drop_last=True)
|
||||
|
||||
if len(classes) < num_classes_in_batch:
|
||||
if is_val:
|
||||
raise RuntimeError(f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !")
|
||||
raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !")
|
||||
|
||||
# set the classes to avoid get wrong class_id when the number of training and eval classes are not equal
|
||||
if is_val:
|
||||
loader = None
|
||||
else:
|
||||
dataset = SpeakerEncoderDataset(
|
||||
ap,
|
||||
meta_data_eval if is_val else meta_data_train,
|
||||
voice_len=c.voice_len,
|
||||
num_utter_per_speaker=c.num_utters_per_speaker,
|
||||
num_speakers_in_batch=c.num_speakers_in_batch,
|
||||
skip_speakers=c.skip_speakers,
|
||||
storage_size=c.storage["storage_size"],
|
||||
sample_from_storage_p=c.storage["sample_from_storage_p"],
|
||||
verbose=verbose,
|
||||
augmentation_config=c.audio_augmentation,
|
||||
use_torch_spec=c.model_params.get("use_torch_spec", False),
|
||||
)
|
||||
dataset.set_classes(train_classes)
|
||||
|
||||
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=c.num_speakers_in_batch,
|
||||
shuffle=False,
|
||||
num_workers=c.num_loader_workers,
|
||||
collate_fn=dataset.collate_fn,
|
||||
)
|
||||
return loader, dataset.get_num_speakers()
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
num_workers=c.num_loader_workers,
|
||||
batch_sampler=sampler,
|
||||
collate_fn=dataset.collate_fn,
|
||||
)
|
||||
|
||||
return loader, classes, dataset.get_map_classid_to_classname()
|
||||
|
||||
def train(model, optimizer, scheduler, criterion, data_loader, global_step):
|
||||
def evaluation(model, criterion, data_loader, global_step):
|
||||
eval_loss = 0
|
||||
for _, data in enumerate(data_loader):
|
||||
with torch.no_grad():
|
||||
# setup input data
|
||||
inputs, labels = data
|
||||
|
||||
# agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1]
|
||||
labels = torch.transpose(labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1).reshape(labels.shape)
|
||||
inputs = torch.transpose(inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
|
||||
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
inputs = inputs.cuda(non_blocking=True)
|
||||
labels = labels.cuda(non_blocking=True)
|
||||
|
||||
# forward pass model
|
||||
outputs = model(inputs)
|
||||
|
||||
# loss computation
|
||||
loss = criterion(outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels)
|
||||
|
||||
eval_loss += loss.item()
|
||||
|
||||
eval_avg_loss = eval_loss/len(data_loader)
|
||||
# save stats
|
||||
dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
|
||||
# plot the last batch in the evaluation
|
||||
figures = {
|
||||
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
|
||||
}
|
||||
dashboard_logger.eval_figures(global_step, figures)
|
||||
return eval_avg_loss
|
||||
|
||||
def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
|
||||
model.train()
|
||||
epoch_time = 0
|
||||
best_loss = float("inf")
|
||||
avg_loss = 0
|
||||
avg_loss_all = 0
|
||||
avg_loader_time = 0
|
||||
end_time = time.time()
|
||||
for epoch in range(c.epochs):
|
||||
tot_loss = 0
|
||||
epoch_time = 0
|
||||
for _, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
|
||||
for _, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
# setup input data
|
||||
inputs, labels = data
|
||||
# agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1]
|
||||
labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
|
||||
inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
|
||||
# ToDo: move it to a unit test
|
||||
# labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
|
||||
# inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
|
||||
# idx = 0
|
||||
# for j in range(0, c.num_classes_in_batch, 1):
|
||||
# for i in range(j, len(labels), c.num_classes_in_batch):
|
||||
# if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])):
|
||||
# print("Invalid")
|
||||
# print(labels)
|
||||
# exit()
|
||||
# idx += 1
|
||||
# labels = labels_converted
|
||||
# inputs = inputs_converted
|
||||
|
||||
# setup input data
|
||||
inputs, labels = data
|
||||
loader_time = time.time() - end_time
|
||||
global_step += 1
|
||||
loader_time = time.time() - end_time
|
||||
global_step += 1
|
||||
|
||||
# setup lr
|
||||
if c.lr_decay:
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
# setup lr
|
||||
if c.lr_decay:
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
inputs = inputs.cuda(non_blocking=True)
|
||||
labels = labels.cuda(non_blocking=True)
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
inputs = inputs.cuda(non_blocking=True)
|
||||
labels = labels.cuda(non_blocking=True)
|
||||
|
||||
# forward pass model
|
||||
outputs = model(inputs)
|
||||
# forward pass model
|
||||
outputs = model(inputs)
|
||||
|
||||
# loss computation
|
||||
loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1), labels)
|
||||
loss.backward()
|
||||
grad_norm, _ = check_update(model, c.grad_clip)
|
||||
optimizer.step()
|
||||
# loss computation
|
||||
loss = criterion(outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels)
|
||||
loss.backward()
|
||||
grad_norm, _ = check_update(model, c.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
|
||||
# Averaged Loss and Averaged Loader Time
|
||||
avg_loss = 0.01 * loss.item() + 0.99 * avg_loss if avg_loss != 0 else loss.item()
|
||||
num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
|
||||
avg_loader_time = (
|
||||
1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
|
||||
if avg_loader_time != 0
|
||||
else loader_time
|
||||
)
|
||||
current_lr = optimizer.param_groups[0]["lr"]
|
||||
# acumulate the total epoch loss
|
||||
tot_loss += loss.item()
|
||||
|
||||
if global_step % c.steps_plot_stats == 0:
|
||||
# Plot Training Epoch Stats
|
||||
train_stats = {
|
||||
"loss": avg_loss,
|
||||
"lr": current_lr,
|
||||
"grad_norm": grad_norm,
|
||||
"step_time": step_time,
|
||||
"avg_loader_time": avg_loader_time,
|
||||
}
|
||||
dashboard_logger.train_epoch_stats(global_step, train_stats)
|
||||
figures = {
|
||||
# FIXME: not constant
|
||||
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
|
||||
}
|
||||
dashboard_logger.train_figures(global_step, figures)
|
||||
|
||||
if global_step % c.print_step == 0:
|
||||
print(
|
||||
" | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} "
|
||||
"StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
|
||||
global_step, loss.item(), avg_loss, grad_norm, step_time, loader_time, avg_loader_time, current_lr
|
||||
),
|
||||
flush=True,
|
||||
# Averaged Loader Time
|
||||
num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
|
||||
avg_loader_time = (
|
||||
1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
|
||||
if avg_loader_time != 0
|
||||
else loader_time
|
||||
)
|
||||
avg_loss_all += avg_loss
|
||||
current_lr = optimizer.param_groups[0]["lr"]
|
||||
|
||||
if global_step >= c.max_train_step or global_step % c.save_step == 0:
|
||||
# save best model only
|
||||
best_loss = save_best_model(model, optimizer, criterion, avg_loss, best_loss, OUT_PATH, global_step)
|
||||
avg_loss_all = 0
|
||||
if global_step >= c.max_train_step:
|
||||
break
|
||||
if global_step % c.steps_plot_stats == 0:
|
||||
# Plot Training Epoch Stats
|
||||
train_stats = {
|
||||
"loss": loss.item(),
|
||||
"lr": current_lr,
|
||||
"grad_norm": grad_norm,
|
||||
"step_time": step_time,
|
||||
"avg_loader_time": avg_loader_time,
|
||||
}
|
||||
dashboard_logger.train_epoch_stats(global_step, train_stats)
|
||||
figures = {
|
||||
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
|
||||
}
|
||||
dashboard_logger.train_figures(global_step, figures)
|
||||
|
||||
end_time = time.time()
|
||||
if global_step % c.print_step == 0:
|
||||
print(
|
||||
" | > Step:{} Loss:{:.5f} GradNorm:{:.5f} "
|
||||
"StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
|
||||
global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
|
||||
return avg_loss, global_step
|
||||
if global_step % c.save_step == 0:
|
||||
# save model
|
||||
save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
print("")
|
||||
print(
|
||||
">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
|
||||
"EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
|
||||
epoch, tot_loss/len(data_loader), grad_norm, epoch_time, avg_loader_time
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
# evaluation
|
||||
if c.run_eval:
|
||||
model.eval()
|
||||
eval_loss = evaluation(model, criterion, eval_data_loader, global_step)
|
||||
print("\n\n")
|
||||
print("--> EVAL PERFORMANCE")
|
||||
print(
|
||||
" | > Epoch:{} AvgLoss: {:.5f} ".format(
|
||||
epoch, eval_loss
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
# save the best checkpoint
|
||||
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
|
||||
model.train()
|
||||
|
||||
return best_loss, global_step
|
||||
|
||||
|
||||
def main(args): # pylint: disable=redefined-outer-name
|
||||
# pylint: disable=global-variable-undefined
|
||||
global meta_data_train
|
||||
global meta_data_eval
|
||||
global train_classes
|
||||
|
||||
ap = AudioProcessor(**c.audio)
|
||||
model = setup_speaker_encoder_model(c)
|
||||
|
||||
optimizer = RAdam(model.parameters(), lr=c.lr)
|
||||
optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False)
|
||||
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)
|
||||
|
||||
data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)
|
||||
|
||||
if c.loss == "ge2e":
|
||||
criterion = GE2ELoss(loss_method="softmax")
|
||||
elif c.loss == "angleproto":
|
||||
criterion = AngleProtoLoss()
|
||||
elif c.loss == "softmaxproto":
|
||||
criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_speakers)
|
||||
train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
|
||||
if c.run_eval:
|
||||
eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)
|
||||
else:
|
||||
raise Exception("The %s not is a loss supported" % c.loss)
|
||||
eval_data_loader = None
|
||||
|
||||
num_classes = len(train_classes)
|
||||
criterion = model.get_criterion(c, num_classes)
|
||||
|
||||
if c.loss == "softmaxproto" and c.model != "speaker_encoder":
|
||||
c.map_classid_to_classname = map_classid_to_classname
|
||||
copy_model_files(c, OUT_PATH)
|
||||
|
||||
if args.restore_path:
|
||||
checkpoint = load_fsspec(args.restore_path)
|
||||
try:
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
|
||||
if "criterion" in checkpoint:
|
||||
criterion.load_state_dict(checkpoint["criterion"])
|
||||
|
||||
except (KeyError, RuntimeError):
|
||||
print(" > Partial model initialization.")
|
||||
model_dict = model.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
|
||||
model.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = c.lr
|
||||
|
||||
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
|
||||
args.restore_step = checkpoint["step"]
|
||||
criterion, args.restore_step = model.load_checkpoint(c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion)
|
||||
print(" > Model restored from step %d" % args.restore_step, flush=True)
|
||||
else:
|
||||
args.restore_step = 0
|
||||
|
||||
|
@ -206,7 +280,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
criterion.cuda()
|
||||
|
||||
global_step = args.restore_step
|
||||
_, global_step = train(model, optimizer, scheduler, criterion, data_loader, global_step)
|
||||
_, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -37,7 +37,7 @@ def register_config(model_name: str) -> Coqpit:
|
|||
"""
|
||||
config_class = None
|
||||
config_name = model_name + "_config"
|
||||
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
|
||||
paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs"]
|
||||
for path in paths:
|
||||
try:
|
||||
config_class = find_module(path, config_name)
|
||||
|
|
|
@ -7,10 +7,10 @@ from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTr
|
|||
|
||||
|
||||
@dataclass
|
||||
class SpeakerEncoderConfig(BaseTrainingConfig):
|
||||
"""Defines parameters for Speaker Encoder model."""
|
||||
class BaseEncoderConfig(BaseTrainingConfig):
|
||||
"""Defines parameters for a Generic Encoder model."""
|
||||
|
||||
model: str = "speaker_encoder"
|
||||
model: str = None
|
||||
audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
|
||||
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
|
||||
# model params
|
||||
|
@ -27,34 +27,33 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
|
|||
|
||||
audio_augmentation: Dict = field(default_factory=lambda: {})
|
||||
|
||||
storage: Dict = field(
|
||||
default_factory=lambda: {
|
||||
"sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage
|
||||
"storage_size": 15, # the size of the in-memory storage with respect to a single batch
|
||||
}
|
||||
)
|
||||
|
||||
# training params
|
||||
max_train_step: int = 1000000 # end training when number of training steps reaches this value.
|
||||
epochs: int = 10000
|
||||
loss: str = "angleproto"
|
||||
grad_clip: float = 3.0
|
||||
lr: float = 0.0001
|
||||
optimizer: str = "radam"
|
||||
optimizer_params: Dict = field(default_factory=lambda: {
|
||||
"betas": [0.9, 0.999],
|
||||
"weight_decay": 0
|
||||
})
|
||||
lr_decay: bool = False
|
||||
warmup_steps: int = 4000
|
||||
wd: float = 1e-6
|
||||
|
||||
# logging params
|
||||
tb_model_param_stats: bool = False
|
||||
steps_plot_stats: int = 10
|
||||
checkpoint: bool = True
|
||||
save_step: int = 1000
|
||||
print_step: int = 20
|
||||
run_eval: bool = False
|
||||
|
||||
# data loader
|
||||
num_speakers_in_batch: int = MISSING
|
||||
num_utters_per_speaker: int = MISSING
|
||||
num_classes_in_batch: int = MISSING
|
||||
num_utter_per_class: int = MISSING
|
||||
eval_num_classes_in_batch: int = None
|
||||
eval_num_utter_per_class: int = None
|
||||
|
||||
num_loader_workers: int = MISSING
|
||||
skip_speakers: bool = False
|
||||
voice_len: float = 1.6
|
||||
|
||||
def check_values(self):
|
|
@ -0,0 +1,12 @@
|
|||
from dataclasses import asdict, dataclass
|
||||
|
||||
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmotionEncoderConfig(BaseEncoderConfig):
|
||||
"""Defines parameters for Emotion Encoder model."""
|
||||
|
||||
model: str = "emotion_encoder"
|
||||
map_classid_to_classname: dict = None
|
||||
class_name_key: str = "emotion_name"
|
|
@ -0,0 +1,11 @@
|
|||
from dataclasses import asdict, dataclass
|
||||
|
||||
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeakerEncoderConfig(BaseEncoderConfig):
|
||||
"""Defines parameters for Speaker Encoder model."""
|
||||
|
||||
model: str = "speaker_encoder"
|
||||
class_name_key: str = "speaker_name"
|
|
@ -0,0 +1,149 @@
|
|||
import random
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from TTS.encoder.utils.generic_utils import AugmentWAV
|
||||
|
||||
class EncoderDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
ap,
|
||||
meta_data,
|
||||
voice_len=1.6,
|
||||
num_classes_in_batch=64,
|
||||
num_utter_per_class=10,
|
||||
verbose=False,
|
||||
augmentation_config=None,
|
||||
use_torch_spec=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
ap (TTS.tts.utils.AudioProcessor): audio processor object.
|
||||
meta_data (list): list of dataset instances.
|
||||
seq_len (int): voice segment length in seconds.
|
||||
verbose (bool): print diagnostic information.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.items = meta_data
|
||||
self.sample_rate = ap.sample_rate
|
||||
self.seq_len = int(voice_len * self.sample_rate)
|
||||
self.num_utter_per_class = num_utter_per_class
|
||||
self.ap = ap
|
||||
self.verbose = verbose
|
||||
self.use_torch_spec = use_torch_spec
|
||||
self.classes, self.items = self.__parse_items()
|
||||
|
||||
self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}
|
||||
|
||||
# Data Augmentation
|
||||
self.augmentator = None
|
||||
self.gaussian_augmentation_config = None
|
||||
if augmentation_config:
|
||||
self.data_augmentation_p = augmentation_config["p"]
|
||||
if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
|
||||
self.augmentator = AugmentWAV(ap, augmentation_config)
|
||||
|
||||
if "gaussian" in augmentation_config.keys():
|
||||
self.gaussian_augmentation_config = augmentation_config["gaussian"]
|
||||
|
||||
if self.verbose:
|
||||
print("\n > DataLoader initialization")
|
||||
print(f" | > Classes per Batch: {num_classes_in_batch}")
|
||||
print(f" | > Number of instances : {len(self.items)}")
|
||||
print(f" | > Sequence length: {self.seq_len}")
|
||||
print(f" | > Num Classes: {len(self.classes)}")
|
||||
print(f" | > Classes: {self.classes}")
|
||||
|
||||
|
||||
def load_wav(self, filename):
|
||||
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
|
||||
return audio
|
||||
|
||||
def __parse_items(self):
|
||||
class_to_utters = {}
|
||||
for item in self.items:
|
||||
path_ = item["audio_file"]
|
||||
class_name = item[self.config.class_name_key]
|
||||
if class_name in class_to_utters.keys():
|
||||
class_to_utters[class_name].append(path_)
|
||||
else:
|
||||
class_to_utters[class_name] = [
|
||||
path_,
|
||||
]
|
||||
|
||||
# skip classes with number of samples >= self.num_utter_per_class
|
||||
class_to_utters = {
|
||||
k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class
|
||||
}
|
||||
|
||||
classes = list(class_to_utters.keys())
|
||||
classes.sort()
|
||||
|
||||
new_items = []
|
||||
for item in self.items:
|
||||
path_ = item["audio_file"]
|
||||
class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
|
||||
# ignore filtered classes
|
||||
if class_name not in classes:
|
||||
continue
|
||||
# ignore small audios
|
||||
if self.load_wav(path_).shape[0] - self.seq_len <= 0:
|
||||
continue
|
||||
|
||||
new_items.append({"wav_file_path": path_, "class_name": class_name})
|
||||
|
||||
return classes, new_items
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.classes)
|
||||
|
||||
def get_class_list(self):
|
||||
return self.classes
|
||||
def set_classes(self, classes):
|
||||
self.classes = classes
|
||||
self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}
|
||||
|
||||
|
||||
def get_map_classid_to_classname(self):
|
||||
return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.items[idx]
|
||||
|
||||
def collate_fn(self, batch):
|
||||
# get the batch class_ids
|
||||
labels = []
|
||||
feats = []
|
||||
for item in batch:
|
||||
utter_path = item["wav_file_path"]
|
||||
class_name = item["class_name"]
|
||||
|
||||
# get classid
|
||||
class_id = self.classname_to_classid[class_name]
|
||||
# load wav file
|
||||
wav = self.load_wav(utter_path)
|
||||
offset = random.randint(0, wav.shape[0] - self.seq_len)
|
||||
wav = wav[offset : offset + self.seq_len]
|
||||
|
||||
if self.augmentator is not None and self.data_augmentation_p:
|
||||
if random.random() < self.data_augmentation_p:
|
||||
wav = self.augmentator.apply_one(wav)
|
||||
|
||||
if not self.use_torch_spec:
|
||||
mel = self.ap.melspectrogram(wav)
|
||||
feats.append(torch.FloatTensor(mel))
|
||||
else:
|
||||
feats.append(torch.FloatTensor(wav))
|
||||
|
||||
labels.append(class_id)
|
||||
|
||||
feats = torch.stack(feats)
|
||||
labels = torch.LongTensor(labels)
|
||||
|
||||
return feats, labels
|
|
@ -189,6 +189,11 @@ class SoftmaxLoss(nn.Module):
|
|||
|
||||
return L
|
||||
|
||||
def inference(self, embedding):
|
||||
x = self.fc(embedding)
|
||||
activations = torch.nn.functional.softmax(x, dim=1).squeeze(0)
|
||||
class_id = torch.argmax(activations)
|
||||
return class_id
|
||||
|
||||
class SoftmaxAngleProtoLoss(nn.Module):
|
||||
"""
|
|
@ -0,0 +1,145 @@
|
|||
import torch
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.utils.generic_utils import set_init_dict
|
||||
from coqpit import Coqpit
|
||||
|
||||
class PreEmphasis(nn.Module):
|
||||
def __init__(self, coefficient=0.97):
|
||||
super().__init__()
|
||||
self.coefficient = coefficient
|
||||
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
|
||||
|
||||
def forward(self, x):
|
||||
assert len(x.size()) == 2
|
||||
|
||||
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
|
||||
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
|
||||
|
||||
class BaseEncoder(nn.Module):
|
||||
"""Base `encoder` class. Every new `encoder` model must inherit this.
|
||||
|
||||
It defines common `encoder` specific functions.
|
||||
"""
|
||||
|
||||
# pylint: disable=W0102
|
||||
def __init__(self):
|
||||
super(BaseEncoder, self).__init__()
|
||||
|
||||
def get_torch_mel_spectrogram_class(self, audio_config):
|
||||
return torch.nn.Sequential(
|
||||
PreEmphasis(audio_config["preemphasis"]),
|
||||
# TorchSTFT(
|
||||
# n_fft=audio_config["fft_size"],
|
||||
# hop_length=audio_config["hop_length"],
|
||||
# win_length=audio_config["win_length"],
|
||||
# sample_rate=audio_config["sample_rate"],
|
||||
# window="hamming_window",
|
||||
# mel_fmin=0.0,
|
||||
# mel_fmax=None,
|
||||
# use_htk=True,
|
||||
# do_amp_to_db=False,
|
||||
# n_mels=audio_config["num_mels"],
|
||||
# power=2.0,
|
||||
# use_mel=True,
|
||||
# mel_norm=None,
|
||||
# )
|
||||
torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=audio_config["sample_rate"],
|
||||
n_fft=audio_config["fft_size"],
|
||||
win_length=audio_config["win_length"],
|
||||
hop_length=audio_config["hop_length"],
|
||||
window_fn=torch.hamming_window,
|
||||
n_mels=audio_config["num_mels"],
|
||||
)
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, l2_norm=True):
|
||||
return self.forward(x, l2_norm)
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
|
||||
"""
|
||||
Generate embeddings for a batch of utterances
|
||||
x: 1xTxD
|
||||
"""
|
||||
# map to the waveform size
|
||||
if self.use_torch_spec:
|
||||
num_frames = num_frames * self.audio_config["hop_length"]
|
||||
|
||||
max_len = x.shape[1]
|
||||
|
||||
if max_len < num_frames:
|
||||
num_frames = max_len
|
||||
|
||||
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
|
||||
|
||||
frames_batch = []
|
||||
for offset in offsets:
|
||||
offset = int(offset)
|
||||
end_offset = int(offset + num_frames)
|
||||
frames = x[:, offset:end_offset]
|
||||
frames_batch.append(frames)
|
||||
|
||||
frames_batch = torch.cat(frames_batch, dim=0)
|
||||
embeddings = self.inference(frames_batch, l2_norm=l2_norm)
|
||||
|
||||
if return_mean:
|
||||
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
|
||||
return embeddings
|
||||
|
||||
def get_criterion(self, c: Coqpit, num_classes=None):
|
||||
if c.loss == "ge2e":
|
||||
criterion = GE2ELoss(loss_method="softmax")
|
||||
elif c.loss == "angleproto":
|
||||
criterion = AngleProtoLoss()
|
||||
elif c.loss == "softmaxproto":
|
||||
criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
|
||||
else:
|
||||
raise Exception("The %s not is a loss supported" % c.loss)
|
||||
return criterion
|
||||
|
||||
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None):
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
try:
|
||||
self.load_state_dict(state["model"])
|
||||
except (KeyError, RuntimeError) as error:
|
||||
# If eval raise the error
|
||||
if eval:
|
||||
raise error
|
||||
|
||||
print(" > Partial model initialization.")
|
||||
model_dict = self.state_dict()
|
||||
model_dict = set_init_dict(model_dict, state["model"], c)
|
||||
self.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
# load the criterion for restore_path
|
||||
if criterion is not None and "criterion" in state:
|
||||
try:
|
||||
criterion.load_state_dict(state["criterion"])
|
||||
except (KeyError, RuntimeError) as error:
|
||||
print(" > Criterion load ignored because of:", error)
|
||||
|
||||
# instance and load the criterion for the encoder classifier in inference time
|
||||
if eval and criterion is None and "criterion" in state and getattr(config, 'map_classid_to_classname', None) is not None:
|
||||
criterion = self.get_criterion(config, len(config.map_classid_to_classname))
|
||||
criterion.load_state_dict(state["criterion"])
|
||||
|
||||
if use_cuda:
|
||||
self.cuda()
|
||||
if criterion is not None:
|
||||
criterion = criterion.cuda()
|
||||
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
if not eval:
|
||||
return criterion, state["step"]
|
||||
return criterion
|
|
@ -0,0 +1,99 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from TTS.encoder.models.base_encoder import BaseEncoder
|
||||
|
||||
|
||||
class LSTMWithProjection(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, proj_size):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.proj_size = proj_size
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
|
||||
self.linear = nn.Linear(hidden_size, proj_size, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
self.lstm.flatten_parameters()
|
||||
o, (_, _) = self.lstm(x)
|
||||
return self.linear(o)
|
||||
|
||||
|
||||
class LSTMWithoutProjection(nn.Module):
|
||||
def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)
|
||||
self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
_, (hidden, _) = self.lstm(x)
|
||||
return self.relu(self.linear(hidden[-1]))
|
||||
|
||||
|
||||
class LSTMSpeakerEncoder(BaseEncoder):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
proj_dim=256,
|
||||
lstm_dim=768,
|
||||
num_lstm_layers=3,
|
||||
use_lstm_with_projection=True,
|
||||
use_torch_spec=False,
|
||||
audio_config=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_lstm_with_projection = use_lstm_with_projection
|
||||
self.use_torch_spec = use_torch_spec
|
||||
self.audio_config = audio_config
|
||||
self.proj_dim = proj_dim
|
||||
|
||||
layers = []
|
||||
# choise LSTM layer
|
||||
if use_lstm_with_projection:
|
||||
layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
|
||||
for _ in range(num_lstm_layers - 1):
|
||||
layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
|
||||
self.layers = nn.Sequential(*layers)
|
||||
else:
|
||||
self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
|
||||
|
||||
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
||||
|
||||
if self.use_torch_spec:
|
||||
self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
|
||||
else:
|
||||
self.torch_spec = None
|
||||
|
||||
self._init_layers()
|
||||
|
||||
def _init_layers(self):
|
||||
for name, param in self.layers.named_parameters():
|
||||
if "bias" in name:
|
||||
nn.init.constant_(param, 0.0)
|
||||
elif "weight" in name:
|
||||
nn.init.xavier_normal_(param)
|
||||
|
||||
def forward(self, x, l2_norm=True):
|
||||
"""Forward pass of the model.
|
||||
|
||||
Args:
|
||||
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
|
||||
to compute the spectrogram on-the-fly.
|
||||
l2_norm (bool): Whether to L2-normalize the outputs.
|
||||
|
||||
Shapes:
|
||||
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
|
||||
"""
|
||||
with torch.no_grad():
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
if self.use_torch_spec:
|
||||
x.squeeze_(1)
|
||||
x = self.torch_spec(x)
|
||||
x = self.instancenorm(x).transpose(1, 2)
|
||||
d = self.layers(x)
|
||||
if self.use_lstm_with_projection:
|
||||
d = d[:, -1]
|
||||
if l2_norm:
|
||||
d = torch.nn.functional.normalize(d, p=2, dim=1)
|
||||
return d
|
|
@ -1,24 +1,8 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch import nn
|
||||
|
||||
# from TTS.utils.audio import TorchSTFT
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
class PreEmphasis(nn.Module):
|
||||
def __init__(self, coefficient=0.97):
|
||||
super().__init__()
|
||||
self.coefficient = coefficient
|
||||
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
|
||||
|
||||
def forward(self, x):
|
||||
assert len(x.size()) == 2
|
||||
|
||||
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
|
||||
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
|
||||
|
||||
from TTS.encoder.models.base_encoder import BaseEncoder
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=8):
|
||||
|
@ -71,7 +55,7 @@ class SEBasicBlock(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
class ResNetSpeakerEncoder(nn.Module):
|
||||
class ResNetSpeakerEncoder(BaseEncoder):
|
||||
"""Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
|
||||
Adapted from: https://github.com/clovaai/voxceleb_trainer
|
||||
"""
|
||||
|
@ -110,32 +94,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
||||
|
||||
if self.use_torch_spec:
|
||||
self.torch_spec = torch.nn.Sequential(
|
||||
PreEmphasis(audio_config["preemphasis"]),
|
||||
# TorchSTFT(
|
||||
# n_fft=audio_config["fft_size"],
|
||||
# hop_length=audio_config["hop_length"],
|
||||
# win_length=audio_config["win_length"],
|
||||
# sample_rate=audio_config["sample_rate"],
|
||||
# window="hamming_window",
|
||||
# mel_fmin=0.0,
|
||||
# mel_fmax=None,
|
||||
# use_htk=True,
|
||||
# do_amp_to_db=False,
|
||||
# n_mels=audio_config["num_mels"],
|
||||
# power=2.0,
|
||||
# use_mel=True,
|
||||
# mel_norm=None,
|
||||
# )
|
||||
torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=audio_config["sample_rate"],
|
||||
n_fft=audio_config["fft_size"],
|
||||
win_length=audio_config["win_length"],
|
||||
hop_length=audio_config["hop_length"],
|
||||
window_fn=torch.hamming_window,
|
||||
n_mels=audio_config["num_mels"],
|
||||
),
|
||||
)
|
||||
self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
|
||||
else:
|
||||
self.torch_spec = None
|
||||
|
||||
|
@ -238,47 +197,3 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
if l2_norm:
|
||||
x = torch.nn.functional.normalize(x, p=2, dim=1)
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, l2_norm=False):
|
||||
return self.forward(x, l2_norm)
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
|
||||
"""
|
||||
Generate embeddings for a batch of utterances
|
||||
x: 1xTxD
|
||||
"""
|
||||
# map to the waveform size
|
||||
if self.use_torch_spec:
|
||||
num_frames = num_frames * self.audio_config["hop_length"]
|
||||
|
||||
max_len = x.shape[1]
|
||||
|
||||
if max_len < num_frames:
|
||||
num_frames = max_len
|
||||
|
||||
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
|
||||
|
||||
frames_batch = []
|
||||
for offset in offsets:
|
||||
offset = int(offset)
|
||||
end_offset = int(offset + num_frames)
|
||||
frames = x[:, offset:end_offset]
|
||||
frames_batch.append(frames)
|
||||
|
||||
frames_batch = torch.cat(frames_batch, dim=0)
|
||||
embeddings = self.inference(frames_batch, l2_norm=l2_norm)
|
||||
|
||||
if return_mean:
|
||||
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
|
||||
return embeddings
|
||||
|
||||
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if use_cuda:
|
||||
self.cuda()
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
|
@ -3,60 +3,15 @@ import glob
|
|||
import os
|
||||
import random
|
||||
import re
|
||||
from multiprocessing import Manager
|
||||
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
|
||||
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
|
||||
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
|
||||
from TTS.utils.io import save_fsspec
|
||||
|
||||
|
||||
class Storage(object):
|
||||
def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
|
||||
# use multiprocessing for threading safe
|
||||
self.storage = Manager().list()
|
||||
self.maxsize = maxsize
|
||||
self.num_speakers_in_batch = num_speakers_in_batch
|
||||
self.num_threads = num_threads
|
||||
self.ignore_last_batch = False
|
||||
|
||||
if storage_batchs >= 3:
|
||||
self.ignore_last_batch = True
|
||||
|
||||
# used for fast random sample
|
||||
self.safe_storage_size = self.maxsize - self.num_threads
|
||||
if self.ignore_last_batch:
|
||||
self.safe_storage_size -= self.num_speakers_in_batch
|
||||
|
||||
def __len__(self):
|
||||
return len(self.storage)
|
||||
|
||||
def full(self):
|
||||
return len(self.storage) >= self.maxsize
|
||||
|
||||
def append(self, item):
|
||||
# if storage is full, remove an item
|
||||
if self.full():
|
||||
self.storage.pop(0)
|
||||
|
||||
self.storage.append(item)
|
||||
|
||||
def get_random_sample(self):
|
||||
# safe storage size considering all threads remove one item from storage in same time
|
||||
storage_size = len(self.storage) - self.num_threads
|
||||
|
||||
if self.ignore_last_batch:
|
||||
storage_size -= self.num_speakers_in_batch
|
||||
|
||||
return self.storage[random.randint(0, storage_size)]
|
||||
|
||||
def get_random_sample_fast(self):
|
||||
"""Call this method only when storage is full"""
|
||||
return self.storage[random.randint(0, self.safe_storage_size)]
|
||||
|
||||
|
||||
class AugmentWAV(object):
|
||||
def __init__(self, ap, augmentation_config):
|
||||
|
||||
|
@ -209,7 +164,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
|
|||
save_fsspec(state, checkpoint_path)
|
||||
|
||||
|
||||
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
|
||||
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
|
||||
if model_loss < best_loss:
|
||||
new_state_dict = model.state_dict()
|
||||
state = {
|
||||
|
@ -217,6 +172,7 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
|
|||
"optimizer": optimizer.state_dict(),
|
||||
"criterion": criterion.state_dict(),
|
||||
"step": current_step,
|
||||
"epoch": epoch,
|
||||
"loss": model_loss,
|
||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
|
@ -0,0 +1,102 @@
|
|||
import random
|
||||
from torch.utils.data.sampler import Sampler, SubsetRandomSampler
|
||||
|
||||
|
||||
class SubsetSampler(Sampler):
|
||||
"""
|
||||
Samples elements sequentially from a given list of indices.
|
||||
|
||||
Args:
|
||||
indices (list): a sequence of indices
|
||||
"""
|
||||
|
||||
def __init__(self, indices):
|
||||
super().__init__(indices)
|
||||
self.indices = indices
|
||||
|
||||
def __iter__(self):
|
||||
return (self.indices[i] for i in range(len(self.indices)))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.indices)
|
||||
|
||||
|
||||
class PerfectBatchSampler(Sampler):
|
||||
"""
|
||||
Samples a mini-batch of indices for a balanced class batching
|
||||
|
||||
Args:
|
||||
dataset_items(list): dataset items to sample from.
|
||||
classes (list): list of classes of dataset_items to sample from.
|
||||
batch_size (int): total number of samples to be sampled in a mini-batch.
|
||||
num_gpus (int): number of GPU in the data parallel mode.
|
||||
shuffle (bool): if True, samples randomly, otherwise samples sequentially.
|
||||
drop_last (bool): if True, drops last incomplete batch.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"):
|
||||
super().__init__(dataset_items)
|
||||
assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
|
||||
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')
|
||||
|
||||
label_indices = {}
|
||||
for idx, item in enumerate(dataset_items):
|
||||
label = item[label_key]
|
||||
if label not in label_indices.keys():
|
||||
label_indices[label] = [idx]
|
||||
else:
|
||||
label_indices[label].append(idx)
|
||||
|
||||
if shuffle:
|
||||
self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
|
||||
else:
|
||||
self._samplers = [SubsetSampler(label_indices[key]) for key in classes]
|
||||
|
||||
self._batch_size = batch_size
|
||||
self._drop_last = drop_last
|
||||
self._dp_devices = num_gpus
|
||||
self._num_classes_in_batch = num_classes_in_batch
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
batch = []
|
||||
if self._num_classes_in_batch != len(self._samplers):
|
||||
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
|
||||
else:
|
||||
valid_samplers_idx = None
|
||||
|
||||
iters = [iter(s) for s in self._samplers]
|
||||
done = False
|
||||
|
||||
while True:
|
||||
b = []
|
||||
for i, it in enumerate(iters):
|
||||
if valid_samplers_idx is not None and i not in valid_samplers_idx:
|
||||
continue
|
||||
idx = next(it, None)
|
||||
if idx is None:
|
||||
done = True
|
||||
break
|
||||
b.append(idx)
|
||||
if done:
|
||||
break
|
||||
batch += b
|
||||
if len(batch) == self._batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if valid_samplers_idx is not None:
|
||||
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
|
||||
|
||||
if not self._drop_last:
|
||||
if len(batch) > 0:
|
||||
groups = len(batch) // self._num_classes_in_batch
|
||||
if groups % self._dp_devices == 0:
|
||||
yield batch
|
||||
else:
|
||||
batch = batch[:(groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
|
||||
if len(batch) > 0:
|
||||
yield batch
|
||||
|
||||
def __len__(self):
|
||||
class_batch_size = self._batch_size // self._num_classes_in_batch
|
||||
return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
|
|
@ -29,14 +29,18 @@ colormap = (
|
|||
)
|
||||
|
||||
|
||||
def plot_embeddings(embeddings, num_utter_per_speaker):
|
||||
embeddings = embeddings[: 10 * num_utter_per_speaker]
|
||||
def plot_embeddings(embeddings, num_classes_in_batch):
|
||||
num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
|
||||
|
||||
# if necessary get just the first 10 classes
|
||||
if num_classes_in_batch > 10:
|
||||
num_classes_in_batch = 10
|
||||
embeddings = embeddings[: num_classes_in_batch * num_utter_per_class]
|
||||
|
||||
model = umap.UMAP()
|
||||
projection = model.fit_transform(embeddings)
|
||||
num_speakers = embeddings.shape[0] // num_utter_per_speaker
|
||||
ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker)
|
||||
ground_truth = np.repeat(np.arange(num_classes_in_batch), num_utter_per_class)
|
||||
colors = [colormap[i] for i in ground_truth]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(16, 10))
|
||||
_ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
|
||||
plt.gca().set_aspect("equal", "datalim")
|
|
@ -1,118 +0,0 @@
|
|||
|
||||
{
|
||||
"model_name": "lstm",
|
||||
"run_name": "mueller91",
|
||||
"run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ",
|
||||
"audio":{
|
||||
// Audio processing parameters
|
||||
"num_mels": 40, // size of the mel spec frame.
|
||||
"fft_size": 400, // number of stft frequency levels. Size of the linear spectogram frame.
|
||||
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
||||
"win_length": 400, // stft window length in ms.
|
||||
"hop_length": 160, // stft window hop-lengh in ms.
|
||||
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
|
||||
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
|
||||
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
|
||||
"min_level_db": -100, // normalization range
|
||||
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
|
||||
"power": 1.5, // value to sharpen wav signals after GL algorithm.
|
||||
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
||||
// Normalization parameters
|
||||
"signal_norm": true, // normalize the spec values in range [0, 1]
|
||||
"symmetric_norm": true, // move normalization to range [-1, 1]
|
||||
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
||||
"clip_norm": true, // clip normalized values into the range.
|
||||
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
|
||||
"do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
||||
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
|
||||
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
|
||||
},
|
||||
"reinit_layers": [],
|
||||
"loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA)
|
||||
"grad_clip": 3.0, // upper limit for gradients for clipping.
|
||||
"epochs": 1000, // total number of epochs to train.
|
||||
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
||||
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
|
||||
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
"steps_plot_stats": 10, // number of steps to plot embeddings.
|
||||
"num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"num_utters_per_speaker": 10, //
|
||||
"skip_speakers": false, // skip speakers with samples less than "num_utters_per_speaker"
|
||||
|
||||
"voice_len": 1.6, // number of seconds for each training instance
|
||||
"num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||
"wd": 0.000001, // Weight decay weight.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
|
||||
"print_step": 20, // Number of steps to log traning on console.
|
||||
"output_path": "../../MozillaTTSOutput/checkpoints/voxceleb_librispeech/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
|
||||
"model": {
|
||||
"input_dim": 40,
|
||||
"proj_dim": 256,
|
||||
"lstm_dim": 768,
|
||||
"num_lstm_layers": 3,
|
||||
"use_lstm_with_projection": true
|
||||
},
|
||||
|
||||
"audio_augmentation": {
|
||||
"p": 0,
|
||||
//add a gaussian noise to the data in order to increase robustness
|
||||
"gaussian":{ // as the insertion of Gaussian noise is quick to be calculated, we added it after loading the wav file, this way, even audios that were reused with the cache can receive this noise
|
||||
"p": 1, // propability of apply this method, 0 is disable
|
||||
"min_amplitude": 0.0,
|
||||
"max_amplitude": 1e-5
|
||||
}
|
||||
},
|
||||
"storage": {
|
||||
"sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
|
||||
"storage_size": 15, // the size of the in-memory storage with respect to a single batch
|
||||
"additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness
|
||||
},
|
||||
"datasets":
|
||||
[
|
||||
{
|
||||
"name": "vctk_slim",
|
||||
"path": "../../../audio-datasets/en/VCTK-Corpus/",
|
||||
"meta_file_train": null,
|
||||
"meta_file_val": null
|
||||
},
|
||||
{
|
||||
"name": "libri_tts",
|
||||
"path": "../../../audio-datasets/en/LibriTTS/train-clean-100",
|
||||
"meta_file_train": null,
|
||||
"meta_file_val": null
|
||||
},
|
||||
{
|
||||
"name": "libri_tts",
|
||||
"path": "../../../audio-datasets/en/LibriTTS/train-clean-360",
|
||||
"meta_file_train": null,
|
||||
"meta_file_val": null
|
||||
},
|
||||
{
|
||||
"name": "libri_tts",
|
||||
"path": "../../../audio-datasets/en/LibriTTS/train-other-500",
|
||||
"meta_file_train": null,
|
||||
"meta_file_val": null
|
||||
},
|
||||
{
|
||||
"name": "voxceleb1",
|
||||
"path": "../../../audio-datasets/en/voxceleb1/",
|
||||
"meta_file_train": null,
|
||||
"meta_file_val": null
|
||||
},
|
||||
{
|
||||
"name": "voxceleb2",
|
||||
"path": "../../../audio-datasets/en/voxceleb2/",
|
||||
"meta_file_train": null,
|
||||
"meta_file_val": null
|
||||
},
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "../../../audio-datasets/en/MozillaCommonVoice",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": "test.tsv"
|
||||
}
|
||||
]
|
||||
}
|
|
@ -1,956 +0,0 @@
|
|||
{
|
||||
"model": "speaker_encoder",
|
||||
"run_name": "speaker_encoder",
|
||||
"run_description": "resnet speaker encoder trained with commonvoice all languages dev and train, Voxceleb 1 dev and Voxceleb 2 dev",
|
||||
// AUDIO PARAMETERS
|
||||
"audio":{
|
||||
// Audio processing parameters
|
||||
"num_mels": 80, // size of the mel spec frame.
|
||||
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
|
||||
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
||||
"win_length": 1024, // stft window length in ms.
|
||||
"hop_length": 256, // stft window hop-lengh in ms.
|
||||
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
|
||||
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
|
||||
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
|
||||
"min_level_db": -100, // normalization range
|
||||
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
|
||||
"power": 1.5, // value to sharpen wav signals after GL algorithm.
|
||||
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
||||
"stft_pad_mode": "reflect",
|
||||
// Normalization parameters
|
||||
"signal_norm": true, // normalize the spec values in range [0, 1]
|
||||
"symmetric_norm": true, // move normalization to range [-1, 1]
|
||||
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
||||
"clip_norm": true, // clip normalized values into the range.
|
||||
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
|
||||
"spec_gain": 20.0,
|
||||
"do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
||||
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
|
||||
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
|
||||
},
|
||||
"reinit_layers": [],
|
||||
|
||||
"loss": "angleproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss
|
||||
"grad_clip": 3.0, // upper limit for gradients for clipping.
|
||||
"max_train_step": 1000000, // total number of steps to train.
|
||||
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
||||
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
|
||||
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
"steps_plot_stats": 100, // number of steps to plot embeddings.
|
||||
|
||||
// Speakers config
|
||||
"num_speakers_in_batch": 200, // Batch size for training.
|
||||
"num_utters_per_speaker": 2, //
|
||||
"skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
|
||||
"voice_len": 2, // number of seconds for each training instance
|
||||
|
||||
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||
"wd": 0.000001, // Weight decay weight.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
"save_step": 1000, // Number of training steps expected to save the best checkpoints in training.
|
||||
"print_step": 50, // Number of steps to log traning on console.
|
||||
"output_path": "../checkpoints/speaker_encoder/angleproto/resnet_voxceleb1_and_voxceleb2-and-common-voice-all-using-angleproto/", // DATASET-RELATED: output path for all training outputs.
|
||||
|
||||
"audio_augmentation": {
|
||||
"p": 0.5, // propability of apply this method, 0 is disable rir and additive noise augmentation
|
||||
"rir":{
|
||||
"rir_path": "/workspace/store/ecasanova/ComParE/RIRS_NOISES/simulated_rirs/",
|
||||
"conv_mode": "full"
|
||||
},
|
||||
"additive":{
|
||||
"sounds_path": "/workspace/store/ecasanova/ComParE/musan/",
|
||||
// list of each of the directories in your data augmentation, if a directory is in "sounds_path" but is not listed here it will be ignored
|
||||
"speech":{
|
||||
"min_snr_in_db": 13,
|
||||
"max_snr_in_db": 20,
|
||||
"min_num_noises": 2,
|
||||
"max_num_noises": 3
|
||||
},
|
||||
"noise":{
|
||||
"min_snr_in_db": 0,
|
||||
"max_snr_in_db": 15,
|
||||
"min_num_noises": 1,
|
||||
"max_num_noises": 1
|
||||
},
|
||||
"music":{
|
||||
"min_snr_in_db": 5,
|
||||
"max_snr_in_db": 15,
|
||||
"min_num_noises": 1,
|
||||
"max_num_noises": 1
|
||||
}
|
||||
},
|
||||
//add a gaussian noise to the data in order to increase robustness
|
||||
"gaussian":{ // as the insertion of Gaussian noise is quick to be calculated, we added it after loading the wav file, this way, even audios that were reused with the cache can receive this noise
|
||||
"p": 0.5, // propability of apply this method, 0 is disable
|
||||
"min_amplitude": 0.0,
|
||||
"max_amplitude": 1e-5
|
||||
}
|
||||
},
|
||||
"model_params": {
|
||||
"model_name": "resnet",
|
||||
"input_dim": 80,
|
||||
"proj_dim": 512
|
||||
},
|
||||
"storage": {
|
||||
"sample_from_storage_p": 0.5, // the probability with which we'll sample from the DataSet in-memory storage
|
||||
"storage_size": 35 // the size of the in-memory storage with respect to a single batch
|
||||
},
|
||||
"datasets":
|
||||
[
|
||||
{
|
||||
"name": "voxceleb2",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox2_dev_aac/",
|
||||
"meta_file_train": null,
|
||||
"meta_file_val": null
|
||||
},
|
||||
{
|
||||
"name": "voxceleb1",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox1_dev_wav/",
|
||||
"meta_file_train": null,
|
||||
"meta_file_val": null
|
||||
},
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
}
|
||||
|
||||
]
|
||||
}
|
|
@ -1,957 +0,0 @@
|
|||
|
||||
{
|
||||
"model": "speaker_encoder",
|
||||
"run_name": "speaker_encoder",
|
||||
"run_description": "resnet speaker encoder trained with commonvoice all languages dev and train, Voxceleb 1 dev and Voxceleb 2 dev",
|
||||
// AUDIO PARAMETERS
|
||||
"audio":{
|
||||
// Audio processing parameters
|
||||
"num_mels": 80, // size of the mel spec frame.
|
||||
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
|
||||
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
||||
"win_length": 1024, // stft window length in ms.
|
||||
"hop_length": 256, // stft window hop-lengh in ms.
|
||||
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
|
||||
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
|
||||
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
|
||||
"min_level_db": -100, // normalization range
|
||||
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
|
||||
"power": 1.5, // value to sharpen wav signals after GL algorithm.
|
||||
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
||||
"stft_pad_mode": "reflect",
|
||||
// Normalization parameters
|
||||
"signal_norm": true, // normalize the spec values in range [0, 1]
|
||||
"symmetric_norm": true, // move normalization to range [-1, 1]
|
||||
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
||||
"clip_norm": true, // clip normalized values into the range.
|
||||
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
|
||||
"spec_gain": 20.0,
|
||||
"do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
||||
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
|
||||
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
|
||||
},
|
||||
"reinit_layers": [],
|
||||
|
||||
"loss": "softmaxproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss
|
||||
"grad_clip": 3.0, // upper limit for gradients for clipping.
|
||||
"max_train_step": 1000000, // total number of steps to train.
|
||||
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
||||
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
|
||||
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
"steps_plot_stats": 100, // number of steps to plot embeddings.
|
||||
|
||||
// Speakers config
|
||||
"num_speakers_in_batch": 200, // Batch size for training.
|
||||
"num_utters_per_speaker": 2, //
|
||||
"skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
|
||||
"voice_len": 2, // number of seconds for each training instance
|
||||
|
||||
"num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||
"wd": 0.000001, // Weight decay weight.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
"save_step": 1000, // Number of training steps expected to save the best checkpoints in training.
|
||||
"print_step": 50, // Number of steps to log traning on console.
|
||||
"output_path": "../../../checkpoints/speaker_encoder/resnet_voxceleb1_and_voxceleb2-and-common-voice-all/", // DATASET-RELATED: output path for all training outputs.
|
||||
|
||||
"audio_augmentation": {
|
||||
"p": 0.5, // propability of apply this method, 0 is disable rir and additive noise augmentation
|
||||
"rir":{
|
||||
"rir_path": "/workspace/store/ecasanova/ComParE/RIRS_NOISES/simulated_rirs/",
|
||||
"conv_mode": "full"
|
||||
},
|
||||
"additive":{
|
||||
"sounds_path": "/workspace/store/ecasanova/ComParE/musan/",
|
||||
// list of each of the directories in your data augmentation, if a directory is in "sounds_path" but is not listed here it will be ignored
|
||||
"speech":{
|
||||
"min_snr_in_db": 13,
|
||||
"max_snr_in_db": 20,
|
||||
"min_num_noises": 2,
|
||||
"max_num_noises": 3
|
||||
},
|
||||
"noise":{
|
||||
"min_snr_in_db": 0,
|
||||
"max_snr_in_db": 15,
|
||||
"min_num_noises": 1,
|
||||
"max_num_noises": 1
|
||||
},
|
||||
"music":{
|
||||
"min_snr_in_db": 5,
|
||||
"max_snr_in_db": 15,
|
||||
"min_num_noises": 1,
|
||||
"max_num_noises": 1
|
||||
}
|
||||
},
|
||||
//add a gaussian noise to the data in order to increase robustness
|
||||
"gaussian":{ // as the insertion of Gaussian noise is quick to be calculated, we added it after loading the wav file, this way, even audios that were reused with the cache can receive this noise
|
||||
"p": 0.5, // propability of apply this method, 0 is disable
|
||||
"min_amplitude": 0.0,
|
||||
"max_amplitude": 1e-5
|
||||
}
|
||||
},
|
||||
"model_params": {
|
||||
"model_name": "resnet",
|
||||
"input_dim": 80,
|
||||
"proj_dim": 512
|
||||
},
|
||||
"storage": {
|
||||
"sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
|
||||
"storage_size": 35 // the size of the in-memory storage with respect to a single batch
|
||||
},
|
||||
"datasets":
|
||||
[
|
||||
{
|
||||
"name": "voxceleb2",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox2_dev_aac/",
|
||||
"meta_file_train": null,
|
||||
"meta_file_val": null
|
||||
},
|
||||
{
|
||||
"name": "voxceleb1",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox1_dev_wav/",
|
||||
"meta_file_train": null,
|
||||
"meta_file_val": null
|
||||
},
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv",
|
||||
"meta_file_train": "train.tsv",
|
||||
"meta_file_val": null
|
||||
},
|
||||
|
||||
{
|
||||
"name": "common_voice",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv",
|
||||
"meta_file_train": "dev.tsv",
|
||||
"meta_file_val": null
|
||||
}
|
||||
|
||||
]
|
||||
}
|
|
@ -1,243 +0,0 @@
|
|||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage
|
||||
|
||||
|
||||
class SpeakerEncoderDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
ap,
|
||||
meta_data,
|
||||
voice_len=1.6,
|
||||
num_speakers_in_batch=64,
|
||||
storage_size=1,
|
||||
sample_from_storage_p=0.5,
|
||||
num_utter_per_speaker=10,
|
||||
skip_speakers=False,
|
||||
verbose=False,
|
||||
augmentation_config=None,
|
||||
use_torch_spec=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
ap (TTS.tts.utils.AudioProcessor): audio processor object.
|
||||
meta_data (list): list of dataset instances.
|
||||
seq_len (int): voice segment length in seconds.
|
||||
verbose (bool): print diagnostic information.
|
||||
"""
|
||||
super().__init__()
|
||||
self.items = meta_data
|
||||
self.sample_rate = ap.sample_rate
|
||||
self.seq_len = int(voice_len * self.sample_rate)
|
||||
self.num_speakers_in_batch = num_speakers_in_batch
|
||||
self.num_utter_per_speaker = num_utter_per_speaker
|
||||
self.skip_speakers = skip_speakers
|
||||
self.ap = ap
|
||||
self.verbose = verbose
|
||||
self.use_torch_spec = use_torch_spec
|
||||
self.__parse_items()
|
||||
storage_max_size = storage_size * num_speakers_in_batch
|
||||
self.storage = Storage(
|
||||
maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch
|
||||
)
|
||||
self.sample_from_storage_p = float(sample_from_storage_p)
|
||||
|
||||
speakers_aux = list(self.speakers)
|
||||
speakers_aux.sort()
|
||||
self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
|
||||
|
||||
# Augmentation
|
||||
self.augmentator = None
|
||||
self.gaussian_augmentation_config = None
|
||||
if augmentation_config:
|
||||
self.data_augmentation_p = augmentation_config["p"]
|
||||
if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
|
||||
self.augmentator = AugmentWAV(ap, augmentation_config)
|
||||
|
||||
if "gaussian" in augmentation_config.keys():
|
||||
self.gaussian_augmentation_config = augmentation_config["gaussian"]
|
||||
|
||||
if self.verbose:
|
||||
print("\n > DataLoader initialization")
|
||||
print(f" | > Speakers per Batch: {num_speakers_in_batch}")
|
||||
print(f" | > Storage Size: {storage_max_size} instances, each with {num_utter_per_speaker} utters")
|
||||
print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
|
||||
print(f" | > Number of instances : {len(self.items)}")
|
||||
print(f" | > Sequence length: {self.seq_len}")
|
||||
print(f" | > Num speakers: {len(self.speakers)}")
|
||||
|
||||
def load_wav(self, filename):
|
||||
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
|
||||
return audio
|
||||
|
||||
def __parse_items(self):
|
||||
self.speaker_to_utters = {}
|
||||
for i in self.items:
|
||||
path_ = i["audio_file"]
|
||||
speaker_ = i["speaker_name"]
|
||||
if speaker_ in self.speaker_to_utters.keys():
|
||||
self.speaker_to_utters[speaker_].append(path_)
|
||||
else:
|
||||
self.speaker_to_utters[speaker_] = [
|
||||
path_,
|
||||
]
|
||||
|
||||
if self.skip_speakers:
|
||||
self.speaker_to_utters = {
|
||||
k: v for (k, v) in self.speaker_to_utters.items() if len(v) >= self.num_utter_per_speaker
|
||||
}
|
||||
|
||||
self.speakers = [k for (k, v) in self.speaker_to_utters.items()]
|
||||
|
||||
def __len__(self):
|
||||
return int(1e10)
|
||||
|
||||
def get_num_speakers(self):
|
||||
return len(self.speakers)
|
||||
|
||||
def __sample_speaker(self, ignore_speakers=None):
|
||||
speaker = random.sample(self.speakers, 1)[0]
|
||||
# if list of speakers_id is provide make sure that it's will be ignored
|
||||
if ignore_speakers and self.speakerid_to_classid[speaker] in ignore_speakers:
|
||||
while True:
|
||||
speaker = random.sample(self.speakers, 1)[0]
|
||||
if self.speakerid_to_classid[speaker] not in ignore_speakers:
|
||||
break
|
||||
|
||||
if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
|
||||
utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
|
||||
else:
|
||||
utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker)
|
||||
return speaker, utters
|
||||
|
||||
def __sample_speaker_utterances(self, speaker):
|
||||
"""
|
||||
Sample all M utterances for the given speaker.
|
||||
"""
|
||||
wavs = []
|
||||
labels = []
|
||||
for _ in range(self.num_utter_per_speaker):
|
||||
# TODO:dummy but works
|
||||
while True:
|
||||
# remove speakers that have num_utter less than 2
|
||||
if len(self.speaker_to_utters[speaker]) > 1:
|
||||
utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
|
||||
else:
|
||||
if speaker in self.speakers:
|
||||
self.speakers.remove(speaker)
|
||||
|
||||
speaker, _ = self.__sample_speaker()
|
||||
continue
|
||||
|
||||
wav = self.load_wav(utter)
|
||||
if wav.shape[0] - self.seq_len > 0:
|
||||
break
|
||||
|
||||
if utter in self.speaker_to_utters[speaker]:
|
||||
self.speaker_to_utters[speaker].remove(utter)
|
||||
|
||||
if self.augmentator is not None and self.data_augmentation_p:
|
||||
if random.random() < self.data_augmentation_p:
|
||||
wav = self.augmentator.apply_one(wav)
|
||||
|
||||
wavs.append(wav)
|
||||
labels.append(self.speakerid_to_classid[speaker])
|
||||
return wavs, labels
|
||||
|
||||
def __getitem__(self, idx):
|
||||
speaker, _ = self.__sample_speaker()
|
||||
speaker_id = self.speakerid_to_classid[speaker]
|
||||
return speaker, speaker_id
|
||||
|
||||
def __load_from_disk_and_storage(self, speaker):
|
||||
# don't sample from storage, but from HDD
|
||||
wavs_, labels_ = self.__sample_speaker_utterances(speaker)
|
||||
# put the newly loaded item into storage
|
||||
self.storage.append((wavs_, labels_))
|
||||
return wavs_, labels_
|
||||
|
||||
def collate_fn(self, batch):
|
||||
# get the batch speaker_ids
|
||||
batch = np.array(batch)
|
||||
speakers_id_in_batch = set(batch[:, 1].astype(np.int32))
|
||||
|
||||
labels = []
|
||||
feats = []
|
||||
speakers = set()
|
||||
|
||||
for speaker, speaker_id in batch:
|
||||
speaker_id = int(speaker_id)
|
||||
|
||||
# ensure that an speaker appears only once in the batch
|
||||
if speaker_id in speakers:
|
||||
|
||||
# remove current speaker
|
||||
if speaker_id in speakers_id_in_batch:
|
||||
speakers_id_in_batch.remove(speaker_id)
|
||||
|
||||
speaker, _ = self.__sample_speaker(ignore_speakers=speakers_id_in_batch)
|
||||
speaker_id = self.speakerid_to_classid[speaker]
|
||||
speakers_id_in_batch.add(speaker_id)
|
||||
|
||||
if random.random() < self.sample_from_storage_p and self.storage.full():
|
||||
# sample from storage (if full)
|
||||
wavs_, labels_ = self.storage.get_random_sample_fast()
|
||||
|
||||
# force choose the current speaker or other not in batch
|
||||
# It's necessary for ideal training with AngleProto and GE2E losses
|
||||
if labels_[0] in speakers_id_in_batch and labels_[0] != speaker_id:
|
||||
attempts = 0
|
||||
while True:
|
||||
wavs_, labels_ = self.storage.get_random_sample_fast()
|
||||
if labels_[0] == speaker_id or labels_[0] not in speakers_id_in_batch:
|
||||
break
|
||||
|
||||
attempts += 1
|
||||
# Try 5 times after that load from disk
|
||||
if attempts >= 5:
|
||||
wavs_, labels_ = self.__load_from_disk_and_storage(speaker)
|
||||
break
|
||||
else:
|
||||
# don't sample from storage, but from HDD
|
||||
wavs_, labels_ = self.__load_from_disk_and_storage(speaker)
|
||||
|
||||
# append speaker for control
|
||||
speakers.add(labels_[0])
|
||||
|
||||
# remove current speaker and append other
|
||||
if speaker_id in speakers_id_in_batch:
|
||||
speakers_id_in_batch.remove(speaker_id)
|
||||
|
||||
speakers_id_in_batch.add(labels_[0])
|
||||
|
||||
# get a random subset of each of the wavs and extract mel spectrograms.
|
||||
feats_ = []
|
||||
for wav in wavs_:
|
||||
offset = random.randint(0, wav.shape[0] - self.seq_len)
|
||||
wav = wav[offset : offset + self.seq_len]
|
||||
# add random gaussian noise
|
||||
if self.gaussian_augmentation_config and self.gaussian_augmentation_config["p"]:
|
||||
if random.random() < self.gaussian_augmentation_config["p"]:
|
||||
wav += np.random.normal(
|
||||
self.gaussian_augmentation_config["min_amplitude"],
|
||||
self.gaussian_augmentation_config["max_amplitude"],
|
||||
size=len(wav),
|
||||
)
|
||||
|
||||
if not self.use_torch_spec:
|
||||
mel = self.ap.melspectrogram(wav)
|
||||
feats_.append(torch.FloatTensor(mel))
|
||||
else:
|
||||
feats_.append(torch.FloatTensor(wav))
|
||||
|
||||
labels.append(torch.LongTensor(labels_))
|
||||
feats.extend(feats_)
|
||||
|
||||
feats = torch.stack(feats)
|
||||
labels = torch.stack(labels)
|
||||
|
||||
return feats, labels
|
|
@ -1,189 +0,0 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch import nn
|
||||
|
||||
from TTS.speaker_encoder.models.resnet import PreEmphasis
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
class LSTMWithProjection(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, proj_size):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.proj_size = proj_size
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
|
||||
self.linear = nn.Linear(hidden_size, proj_size, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
self.lstm.flatten_parameters()
|
||||
o, (_, _) = self.lstm(x)
|
||||
return self.linear(o)
|
||||
|
||||
|
||||
class LSTMWithoutProjection(nn.Module):
|
||||
def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)
|
||||
self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
_, (hidden, _) = self.lstm(x)
|
||||
return self.relu(self.linear(hidden[-1]))
|
||||
|
||||
|
||||
class LSTMSpeakerEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
proj_dim=256,
|
||||
lstm_dim=768,
|
||||
num_lstm_layers=3,
|
||||
use_lstm_with_projection=True,
|
||||
use_torch_spec=False,
|
||||
audio_config=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_lstm_with_projection = use_lstm_with_projection
|
||||
self.use_torch_spec = use_torch_spec
|
||||
self.audio_config = audio_config
|
||||
self.proj_dim = proj_dim
|
||||
|
||||
layers = []
|
||||
# choise LSTM layer
|
||||
if use_lstm_with_projection:
|
||||
layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
|
||||
for _ in range(num_lstm_layers - 1):
|
||||
layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
|
||||
self.layers = nn.Sequential(*layers)
|
||||
else:
|
||||
self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
|
||||
|
||||
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
||||
|
||||
if self.use_torch_spec:
|
||||
self.torch_spec = torch.nn.Sequential(
|
||||
PreEmphasis(audio_config["preemphasis"]),
|
||||
# TorchSTFT(
|
||||
# n_fft=audio_config["fft_size"],
|
||||
# hop_length=audio_config["hop_length"],
|
||||
# win_length=audio_config["win_length"],
|
||||
# sample_rate=audio_config["sample_rate"],
|
||||
# window="hamming_window",
|
||||
# mel_fmin=0.0,
|
||||
# mel_fmax=None,
|
||||
# use_htk=True,
|
||||
# do_amp_to_db=False,
|
||||
# n_mels=audio_config["num_mels"],
|
||||
# power=2.0,
|
||||
# use_mel=True,
|
||||
# mel_norm=None,
|
||||
# )
|
||||
torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=audio_config["sample_rate"],
|
||||
n_fft=audio_config["fft_size"],
|
||||
win_length=audio_config["win_length"],
|
||||
hop_length=audio_config["hop_length"],
|
||||
window_fn=torch.hamming_window,
|
||||
n_mels=audio_config["num_mels"],
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.torch_spec = None
|
||||
|
||||
self._init_layers()
|
||||
|
||||
def _init_layers(self):
|
||||
for name, param in self.layers.named_parameters():
|
||||
if "bias" in name:
|
||||
nn.init.constant_(param, 0.0)
|
||||
elif "weight" in name:
|
||||
nn.init.xavier_normal_(param)
|
||||
|
||||
def forward(self, x, l2_norm=True):
|
||||
"""Forward pass of the model.
|
||||
|
||||
Args:
|
||||
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
|
||||
to compute the spectrogram on-the-fly.
|
||||
l2_norm (bool): Whether to L2-normalize the outputs.
|
||||
|
||||
Shapes:
|
||||
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
|
||||
"""
|
||||
with torch.no_grad():
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
if self.use_torch_spec:
|
||||
x.squeeze_(1)
|
||||
x = self.torch_spec(x)
|
||||
x = self.instancenorm(x).transpose(1, 2)
|
||||
d = self.layers(x)
|
||||
if self.use_lstm_with_projection:
|
||||
d = d[:, -1]
|
||||
if l2_norm:
|
||||
d = torch.nn.functional.normalize(d, p=2, dim=1)
|
||||
return d
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, l2_norm=True):
|
||||
d = self.forward(x, l2_norm=l2_norm)
|
||||
return d
|
||||
|
||||
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
|
||||
"""
|
||||
Generate embeddings for a batch of utterances
|
||||
x: 1xTxD
|
||||
"""
|
||||
max_len = x.shape[1]
|
||||
|
||||
if max_len < num_frames:
|
||||
num_frames = max_len
|
||||
|
||||
offsets = np.linspace(0, max_len - num_frames, num=num_eval)
|
||||
|
||||
frames_batch = []
|
||||
for offset in offsets:
|
||||
offset = int(offset)
|
||||
end_offset = int(offset + num_frames)
|
||||
frames = x[:, offset:end_offset]
|
||||
frames_batch.append(frames)
|
||||
|
||||
frames_batch = torch.cat(frames_batch, dim=0)
|
||||
embeddings = self.inference(frames_batch)
|
||||
|
||||
if return_mean:
|
||||
embeddings = torch.mean(embeddings, dim=0, keepdim=True)
|
||||
|
||||
return embeddings
|
||||
|
||||
def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
|
||||
"""
|
||||
Generate embeddings for a batch of utterances
|
||||
x: BxTxD
|
||||
"""
|
||||
num_overlap = num_frames * overlap
|
||||
max_len = x.shape[1]
|
||||
embed = None
|
||||
num_iters = seq_lens / (num_frames - num_overlap)
|
||||
cur_iter = 0
|
||||
for offset in range(0, max_len, num_frames - num_overlap):
|
||||
cur_iter += 1
|
||||
end_offset = min(x.shape[1], offset + num_frames)
|
||||
frames = x[:, offset:end_offset]
|
||||
if embed is None:
|
||||
embed = self.inference(frames)
|
||||
else:
|
||||
embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
|
||||
return embed / num_iters
|
||||
|
||||
# pylint: disable=unused-argument, redefined-builtin
|
||||
def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False):
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if use_cuda:
|
||||
self.cuda()
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
Binary file not shown.
Before Width: | Height: | Size: 24 KiB |
|
@ -264,7 +264,7 @@ class BaseTTSConfig(BaseTrainingConfig):
|
|||
# dataset
|
||||
datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
|
||||
# optimizer
|
||||
optimizer: str = None
|
||||
optimizer: str = "radam"
|
||||
optimizer_params: dict = None
|
||||
# scheduler
|
||||
lr_scheduler: str = ""
|
||||
|
|
|
@ -441,6 +441,26 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
|
|||
return [x.strip().split("|") for x in f.readlines()]
|
||||
|
||||
|
||||
def emotion(root_path, meta_file, ignored_speakers=None):
|
||||
"""Generic emotion dataset"""
|
||||
txt_file = os.path.join(root_path, meta_file)
|
||||
items = []
|
||||
with open(txt_file, "r", encoding="utf-8") as ttf:
|
||||
for line in ttf:
|
||||
if line.startswith("file_path"):
|
||||
continue
|
||||
cols = line.split(",")
|
||||
wav_file = os.path.join(root_path, cols[0])
|
||||
speaker_id = cols[1]
|
||||
emotion_id = cols[2].replace("\n", "")
|
||||
# ignore speakers
|
||||
if isinstance(ignored_speakers, list):
|
||||
if speaker_id in ignored_speakers:
|
||||
continue
|
||||
items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
|
||||
return items
|
||||
|
||||
|
||||
def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylint: disable=unused-argument
|
||||
"""Normalizes the Baker meta data file to TTS format
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch
|
|||
from coqpit import Coqpit
|
||||
|
||||
from TTS.config import get_from_config_or_model_args_with_default, load_config
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
|
@ -269,7 +269,7 @@ class SpeakerManager:
|
|||
"""
|
||||
self.speaker_encoder_config = load_config(config_path)
|
||||
self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config)
|
||||
self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)
|
||||
self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda)
|
||||
self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio)
|
||||
|
||||
def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list:
|
||||
|
|
|
@ -3,9 +3,9 @@ import unittest
|
|||
import torch as T
|
||||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
|
||||
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
|
||||
|
||||
file_path = get_tests_input_path()
|
||||
|
||||
|
|
|
@ -4,14 +4,14 @@ import shutil
|
|||
|
||||
from tests import get_device_id, get_tests_output_path, run_cli
|
||||
from TTS.config.shared_configs import BaseAudioConfig
|
||||
from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig
|
||||
from TTS.encoder.configs.speaker_encoder_config import SpeakerEncoderConfig
|
||||
|
||||
|
||||
def run_test_train():
|
||||
command = (
|
||||
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
|
||||
f"--coqpit.output_path {output_path} "
|
||||
"--coqpit.datasets.0.name ljspeech "
|
||||
"--coqpit.datasets.0.name ljspeech_test "
|
||||
"--coqpit.datasets.0.meta_file_train metadata.csv "
|
||||
"--coqpit.datasets.0.meta_file_val metadata.csv "
|
||||
"--coqpit.datasets.0.path tests/data/ljspeech "
|
||||
|
@ -24,17 +24,21 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
|||
|
||||
config = SpeakerEncoderConfig(
|
||||
batch_size=4,
|
||||
num_speakers_in_batch=1,
|
||||
num_utters_per_speaker=10,
|
||||
num_loader_workers=0,
|
||||
max_train_step=2,
|
||||
num_classes_in_batch=4,
|
||||
num_utter_per_class=2,
|
||||
eval_num_classes_in_batch=4,
|
||||
eval_num_utter_per_class=2,
|
||||
num_loader_workers=1,
|
||||
epochs=1,
|
||||
print_step=1,
|
||||
save_step=1,
|
||||
save_step=2,
|
||||
print_eval=True,
|
||||
run_eval=True,
|
||||
audio=BaseAudioConfig(num_mels=80),
|
||||
)
|
||||
config.audio.do_trim_silence = True
|
||||
config.audio.trim_db = 60
|
||||
config.loss = "ge2e"
|
||||
config.save_json(config_path)
|
||||
|
||||
print(config)
|
||||
|
@ -69,14 +73,14 @@ run_cli(command_train)
|
|||
shutil.rmtree(continue_path)
|
||||
|
||||
# test model with ge2e loss function
|
||||
config.loss = "ge2e"
|
||||
config.save_json(config_path)
|
||||
run_test_train()
|
||||
# config.loss = "ge2e"
|
||||
# config.save_json(config_path)
|
||||
# run_test_train()
|
||||
|
||||
# test model with angleproto loss function
|
||||
config.loss = "angleproto"
|
||||
config.save_json(config_path)
|
||||
run_test_train()
|
||||
# config.loss = "angleproto"
|
||||
# config.save_json(config_path)
|
||||
# run_test_train()
|
||||
|
||||
# test model with softmaxproto loss function
|
||||
config.loss = "softmaxproto"
|
||||
|
|
|
@ -6,8 +6,8 @@ import torch
|
|||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.config import load_config
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
from TTS.speaker_encoder.utils.io import save_checkpoint
|
||||
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
from TTS.encoder.utils.io import save_checkpoint
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from TTS.config.shared_configs import BaseDatasetConfig
|
|||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.utils.languages import get_language_balancer_weights
|
||||
from TTS.tts.utils.speakers import get_speaker_balancer_weights
|
||||
from TTS.encoder.utils.samplers import PerfectBatchSampler
|
||||
|
||||
# Fixing random state to avoid random fails
|
||||
torch.manual_seed(0)
|
||||
|
@ -82,3 +83,51 @@ class TestSamplers(unittest.TestCase):
|
|||
spk2 += 1
|
||||
|
||||
assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
|
||||
|
||||
def test_perfect_sampler(self): # pylint: disable=no-self-use
|
||||
classes = set()
|
||||
for item in train_samples:
|
||||
classes.add(item["speaker_name"])
|
||||
|
||||
sampler = PerfectBatchSampler(
|
||||
train_samples,
|
||||
classes,
|
||||
batch_size=2 * 3, # total batch size
|
||||
num_classes_in_batch=2,
|
||||
label_key="speaker_name",
|
||||
shuffle=False,
|
||||
drop_last=True)
|
||||
batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
|
||||
for batch in batchs:
|
||||
spk1, spk2 = 0, 0
|
||||
# for in each batch
|
||||
for index in batch:
|
||||
if train_samples[index]["speaker_name"] == "ljspeech-0":
|
||||
spk1 += 1
|
||||
else:
|
||||
spk2 += 1
|
||||
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
|
||||
|
||||
def test_perfect_sampler_shuffle(self): # pylint: disable=no-self-use
|
||||
classes = set()
|
||||
for item in train_samples:
|
||||
classes.add(item["speaker_name"])
|
||||
|
||||
sampler = PerfectBatchSampler(
|
||||
train_samples,
|
||||
classes,
|
||||
batch_size=2 * 3, # total batch size
|
||||
num_classes_in_batch=2,
|
||||
label_key="speaker_name",
|
||||
shuffle=True,
|
||||
drop_last=False)
|
||||
batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
|
||||
for batch in batchs:
|
||||
spk1, spk2 = 0, 0
|
||||
# for in each batch
|
||||
for index in batch:
|
||||
if train_samples[index]["speaker_name"] == "ljspeech-0":
|
||||
spk1 += 1
|
||||
else:
|
||||
spk2 += 1
|
||||
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"
|
||||
|
|
|
@ -66,8 +66,8 @@
|
|||
"use_mas": false, // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments.
|
||||
|
||||
// TRAINING
|
||||
"batch_size": 2, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"eval_batch_size":1,
|
||||
"batch_size": 8, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"eval_batch_size": 8,
|
||||
"r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
|
||||
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
||||
"data_dep_init_iter": 1,
|
||||
|
|
|
@ -36,8 +36,8 @@
|
|||
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
"steps_plot_stats": 10, // number of steps to plot embeddings.
|
||||
"num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"num_utters_per_speaker": 10, //
|
||||
"num_classes_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"num_utter_per_class": 10, //
|
||||
"num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||
"wd": 0.000001, // Weight decay weight.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
|
|
|
@ -61,8 +61,8 @@
|
|||
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
|
||||
|
||||
// TRAINING
|
||||
"batch_size": 1, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"eval_batch_size":1,
|
||||
"batch_size": 8, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"eval_batch_size": 8,
|
||||
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
|
||||
"gradual_training": [[0, 7, 4], [1, 5, 2]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed.
|
||||
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
||||
|
|
|
@ -61,8 +61,8 @@
|
|||
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
|
||||
|
||||
// TRAINING
|
||||
"batch_size": 1, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"eval_batch_size":1,
|
||||
"batch_size": 8, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"eval_batch_size": 8,
|
||||
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
|
||||
"gradual_training": [[0, 7, 4], [1, 5, 2]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed.
|
||||
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
||||
|
|
|
@ -7,7 +7,7 @@ from trainer.logging.tensorboard_logger import TensorboardLogger
|
|||
|
||||
from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
|
||||
from TTS.config import load_config
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model
|
||||
from TTS.tts.configs.vits_config import VitsConfig
|
||||
from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
|
|
Loading…
Reference in New Issue