mirror of https://github.com/coqui-ai/TTS.git

reading all speakers upfront

parent 6390c3b2e6
commit 2f2482f9b4
@@ -76,6 +76,6 @@
     "use_phonemes": true,  // use phonemes instead of raw characters. It is suggested for better pronunciation.
     "phoneme_language": "en-us",  // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
     "text_cleaner": "phoneme_cleaners",
-    "num_speakers": 10  // should just be bigger than the actual number of speakers, 0 disables speaker embeddings
+    "use_speaker_embedding": false  // whether to use additional embeddings for separate speakers
 }
@@ -77,6 +77,6 @@
     "use_phonemes": true,  // use phonemes instead of raw characters. It is suggested for better pronunciation.
     "phoneme_language": "en-us",  // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
     "text_cleaner": "phoneme_cleaners",
-    "num_speakers": 10  // should just be bigger than the actual number of speakers, 0 disables speaker embeddings
+    "use_speaker_embedding": false  // whether to use additional embeddings for separate speakers
 }
@@ -79,6 +79,6 @@
     "use_phonemes": true,  // use phonemes instead of raw characters. It is suggested for better pronunciation.
     "phoneme_language": "en-us",  // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
     "text_cleaner": "phoneme_cleaners",
-    "num_speakers": 10  // should just be bigger than the actual number of speakers, 0 disables speaker embeddings
+    "use_speaker_embedding": false  // whether to use additional embeddings for separate speakers
 }
@@ -77,6 +77,6 @@
     "use_phonemes": false,  // use phonemes instead of raw characters. It is suggested for better pronunciation.
     "phoneme_language": "de",  // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
     "text_cleaner": "phoneme_cleaners",
-    "num_speakers": 10  // should just be bigger than the actual number of speakers, 0 disables speaker embeddings
+    "use_speaker_embedding": false  // whether to use additional embeddings for separate speakers
 }
@@ -77,6 +77,6 @@
     "use_phonemes": true,  // use phonemes instead of raw characters. It is suggested for better pronunciation.
     "phoneme_language": "en-us",  // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
     "text_cleaner": "phoneme_cleaners",
-    "num_speakers": 10  // should just be bigger than the actual number of speakers, 0 disables speaker embeddings
+    "use_speaker_embedding": false  // whether to use additional embeddings for separate speakers
 }
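All five config hunks carry the same change: the hand-maintained `num_speakers` count is dropped and a boolean `use_speaker_embedding` is added; as the train.py hunks below show, the trainer now derives the speaker count from the dataset itself. A minimal sketch of the resulting speaker-related keys, as a Python dict with illustrative values:

    config = {
        "use_phonemes": True,
        "phoneme_language": "en-us",
        "text_cleaner": "phoneme_cleaners",
        "use_speaker_embedding": False,  # True turns on per-speaker embeddings
    }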
datasets/preprocess.py

@@ -1,6 +1,13 @@
 import os
 from glob import glob
 import re
+import sys
 
 
+def get_preprocessor_by_name(name):
+    """Returns the respective preprocessing function."""
+    thismodule = sys.modules[__name__]
+    return getattr(thismodule, name.lower())
+
+
 def tweb(root_path, meta_file):
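A usage sketch for the new helper: the lookup is just a case-insensitive `getattr` on this module, so any preprocessor defined here can be selected by its config name. `tweb` is the function shown above; the paths are placeholders, not real dataset locations:

    from datasets.preprocess import get_preprocessor_by_name

    preprocessor = get_preprocessor_by_name("TWEB")  # resolves to tweb()
    # items = preprocessor("/data/tweb", "metadata.txt")  # placeholder paths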
train.py (64 changed lines)
@@ -1,8 +1,5 @@
 import argparse
-import importlib
-import json
 import os
-import shutil
 import sys
 import time
 import traceback
@@ -27,10 +24,11 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  set_init_dict, copy_config_file, setup_model)
 from utils.logger import Logger
 from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
-    copy_speaker_mapping
+    get_speakers
 from utils.synthesis import synthesis
 from utils.text.symbols import phonemes, symbols
 from utils.visual import plot_alignment, plot_spectrogram
+from datasets.preprocess import get_preprocessor_by_name
 
 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = False
@@ -51,7 +49,7 @@ def setup_loader(is_val=False, verbose=False):
         c.meta_file_val if is_val else c.meta_file_train,
         c.r,
         c.text_cleaner,
-        preprocessor=preprocessor,
+        preprocessor=get_preprocessor_by_name(c.dataset),
         ap=ap,
         batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
         min_seq_len=0 if is_val else c.min_seq_len,
@@ -78,7 +76,7 @@ def setup_loader(is_val=False, verbose=False):
 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
           ap, epoch):
     data_loader = setup_loader(is_val=False, verbose=(epoch==0))
-    if c.num_speakers > 1:
+    if c.use_speaker_embedding:
         speaker_mapping = load_speaker_mapping(OUT_PATH)
     model.train()
     epoch_time = 0
@@ -102,21 +100,10 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         avg_text_length = torch.mean(text_lengths.float())
         avg_spec_length = torch.mean(mel_lengths.float())
 
-        if c.num_speakers > 1:
-            speaker_ids = []
-            for speaker_name in speaker_names:
-                if speaker_name not in speaker_mapping:
-                    speaker_mapping[speaker_name] = len(speaker_mapping)
-                speaker_ids.append(speaker_mapping[speaker_name])
+        if c.use_speaker_embedding:
+            speaker_ids = [speaker_mapping[speaker_name]
+                           for speaker_name in speaker_names]
             speaker_ids = torch.LongTensor(speaker_ids)
-
-            if len(speaker_mapping) > c.num_speakers:
-                raise ValueError("It seems there are at least {} speakers in "
-                                 "your dataset, while 'num_speakers' is set to "
-                                 "{}. Found the following speakers: {}".format(
-                                     len(speaker_mapping),
-                                     c.num_speakers,
-                                     list(speaker_mapping)))
         else:
             speaker_ids = None
 
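Because the mapping is now read-only during training, the per-batch work reduces to a lookup. A standalone sketch with illustrative speaker names (the real mapping is built once in main(), see below):

    import torch

    speaker_mapping = {"p225": 0, "p226": 1}  # illustrative speakers.json content
    speaker_names = ["p226", "p225", "p226"]  # one name per batch item

    speaker_ids = torch.LongTensor(
        [speaker_mapping[name] for name in speaker_names])  # tensor([1, 0, 1])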
@@ -271,15 +258,12 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if c.tb_model_param_stats:
             tb_logger.tb_model_weights(model, current_step)
 
-    # save speaker mapping
-    if c.num_speakers > 1:
-        save_speaker_mapping(OUT_PATH, speaker_mapping)
     return avg_postnet_loss, current_step
 
 
 def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
     data_loader = setup_loader(is_val=True)
-    if c.num_speakers > 1:
+    if c.use_speaker_embedding:
         speaker_mapping = load_speaker_mapping(OUT_PATH)
     model.eval()
     epoch_time = 0
@@ -311,7 +295,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
             mel_lengths = data[5]
             stop_targets = data[6]
 
-            if c.num_speakers > 1:
+            if c.use_speaker_embedding:
                 speaker_ids = [speaker_mapping[speaker_name]
                                for speaker_name in speaker_names]
                 speaker_ids = torch.LongTensor(speaker_ids)
@@ -443,7 +427,27 @@ def main(args):
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-    model = setup_model(num_chars, c)
+
+    if c.use_speaker_embedding:
+        speakers = get_speakers(c.data_path, c.meta_file_train, c.dataset)
+        if args.restore_path:
+            prev_out_path = os.path.dirname(args.restore_path)
+            speaker_mapping = load_speaker_mapping(prev_out_path)
+            assert all([speaker in speaker_mapping
+                        for speaker in speakers]), "As of now, you cannot " \
+                                                   "introduce new speakers to " \
+                                                   "a previously trained model."
+        else:
+            speaker_mapping = {name: i
+                               for i, name in enumerate(speakers)}
+            save_speaker_mapping(OUT_PATH, speaker_mapping)
+        num_speakers = len(speaker_mapping)
+        print("Training with {} speakers: {}".format(num_speakers,
+                                                     ", ".join(speakers)))
+    else:
+        num_speakers = 0
+
+    model = setup_model(num_chars, num_speakers, c)
 
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
 
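This block is the heart of the commit: speakers are enumerated once, before training starts, instead of being discovered batch by batch. A sketch of that bookkeeping factored into a standalone function (a hypothetical helper for illustration, not part of the diff):

    def build_or_check_speaker_mapping(speakers, restored_mapping=None):
        # Hypothetical helper mirroring the logic in main() above.
        if restored_mapping is not None:
            # Restoring: every dataset speaker must already be known.
            assert all(speaker in restored_mapping for speaker in speakers), \
                "Cannot introduce new speakers to a previously trained model."
            return restored_mapping
        # Fresh run: stable ids come from enumerating the sorted speaker list.
        return {name: i for i, name in enumerate(speakers)}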
@@ -482,10 +486,6 @@ def main(args):
             " > Model restored from step %d" % checkpoint['step'], flush=True)
         start_epoch = checkpoint['epoch']
         args.restore_step = checkpoint['step']
-        # copying speakers.json
-        prev_out_path = os.path.dirname(args.restore_path)
-        if c.num_speakers > 1:
-            copy_speaker_mapping(prev_out_path, OUT_PATH)
     else:
         args.restore_step = 0
 
@@ -607,10 +607,6 @@ if __name__ == '__main__':
     LOG_DIR = OUT_PATH
     tb_logger = Logger(LOG_DIR)
 
-    # Conditional imports
-    preprocessor = importlib.import_module('datasets.preprocess')
-    preprocessor = getattr(preprocessor, c.dataset.lower())
-
     # Audio processor
     ap = AudioProcessor(**c.audio)
 
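The removed lines were the last use of importlib in train.py; the module-level dynamic import is replaced by the helper, which resolves the same function object. An equivalence sketch, assuming the repository root is on sys.path and using the `tweb` preprocessor shown earlier:

    import importlib
    from datasets.preprocess import get_preprocessor_by_name

    # old (removed): resolved once at startup from c.dataset
    module = importlib.import_module('datasets.preprocess')
    old_preprocessor = getattr(module, 'tweb')  # c.dataset.lower() in the original

    # new: resolved on demand inside setup_loader()
    new_preprocessor = get_preprocessor_by_name('tweb')

    assert old_preprocessor is new_preprocessor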
utils/generic_utils.py
@@ -251,14 +251,14 @@ def set_init_dict(model_dict, checkpoint, c):
     return model_dict
 
 
-def setup_model(num_chars, c):
+def setup_model(num_chars, num_speakers, c):
     print(" > Using model: {}".format(c.model))
     MyModel = importlib.import_module('models.' + c.model.lower())
     MyModel = getattr(MyModel, c.model)
     if c.model.lower() in ["tacotron", "tacotrongst"]:
         model = MyModel(
             num_chars=num_chars,
-            num_speakers=c.num_speakers,
+            num_speakers=num_speakers,
             r=c.r,
             linear_dim=1025,
             mel_dim=80,
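setup_model() now receives the dataset-derived speaker count instead of reading it from the config. How the model consumes it is not shown in this diff; a plausible sketch is a per-speaker lookup table that is skipped when num_speakers is 0 (an assumption about the Tacotron internals, for illustration only):

    import torch.nn as nn

    class SpeakerEmbeddingSketch(nn.Module):
        # Hypothetical module, not the actual Tacotron code.
        def __init__(self, num_speakers, embedding_dim=64):
            super().__init__()
            self.embedding = (nn.Embedding(num_speakers, embedding_dim)
                              if num_speakers > 0 else None)

        def forward(self, speaker_ids):
            # One embedding vector per utterance, or None when training
            # a single-speaker model (num_speakers == 0).
            if self.embedding is None:
                return None
            return self.embedding(speaker_ids)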
utils/speakers.py
@@ -1,6 +1,8 @@
 import os
 import json
 
+from datasets.preprocess import get_preprocessor_by_name
+
 
 def make_speakers_json_path(out_path):
     """Returns conventional speakers.json location."""
@@ -23,8 +25,9 @@ def save_speaker_mapping(out_path, speaker_mapping):
         json.dump(speaker_mapping, f, indent=4)
 
 
-def copy_speaker_mapping(out_path_a, out_path_b):
-    """Copies a speaker mapping when restoring a model from a previous path."""
-    speaker_mapping = load_speaker_mapping(out_path_a)
-    if speaker_mapping is not None:
-        save_speaker_mapping(out_path_b, speaker_mapping)
+def get_speakers(data_root, meta_file, dataset_type):
+    """Returns a sorted, unique list of speakers in a given dataset."""
+    preprocessor = get_preprocessor_by_name(dataset_type)
+    items = preprocessor(data_root, meta_file)
+    speakers = {e[2] for e in items}
+    return sorted(speakers)
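get_speakers() relies on the convention that every preprocessor yields items whose third field (e[2]) is the speaker name. A toy run of the dedupe-and-sort step with an inline items list (illustrative data):

    items = [
        ("hello there", "wavs/a.wav", "speaker_b"),
        ("general kenobi", "wavs/b.wav", "speaker_a"),
        ("another clip", "wavs/c.wav", "speaker_a"),
    ]
    speakers = sorted({e[2] for e in items})
    print(speakers)  # ['speaker_a', 'speaker_b']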