reading all speakers upfront

pull/10/head
Thomas Werkmeister 2019-07-10 18:38:55 +02:00
parent 6390c3b2e6
commit 2f2482f9b4
9 changed files with 52 additions and 46 deletions

View File

@ -76,6 +76,6 @@
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners", "text_cleaner": "phoneme_cleaners",
"num_speakers": 10 // should just be bigger than the actual number of speakers, 0 disables speaker embeddings "use_speaker_embedding": false // whether to use additional embeddings for separate speakers
} }

View File

@ -77,6 +77,6 @@
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners", "text_cleaner": "phoneme_cleaners",
"num_speakers": 10 // should just be bigger than the actual number of speakers, 0 disables speaker embeddings "use_speaker_embedding": false // whether to use additional embeddings for separate speakers
} }

View File

@ -79,6 +79,6 @@
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners", "text_cleaner": "phoneme_cleaners",
"num_speakers": 10 // should just be bigger than the actual number of speakers, 0 disables speaker embeddings "use_speaker_embedding": false // whether to use additional embeddings for separate speakers
} }

View File

@ -77,6 +77,6 @@
"use_phonemes": false, // use phonemes instead of raw characters. It is suggested for better pronounciation. "use_phonemes": false, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "de", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages "phoneme_language": "de", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners", "text_cleaner": "phoneme_cleaners",
"num_speakers": 10 // should just be bigger than the actual number of speakers, 0 disables speaker embeddings "use_speaker_embedding": false // whether to use additional embeddings for separate speakers
} }

View File

@ -77,6 +77,6 @@
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners", "text_cleaner": "phoneme_cleaners",
"num_speakers": 10 // should just be bigger than the actual number of speakers, 0 disables speaker embeddings "use_speaker_embedding": false // whether to use additional embeddings for separate speakers
} }

View File

@ -1,6 +1,13 @@
import os import os
from glob import glob from glob import glob
import re import re
import sys
def get_preprocessor_by_name(name):
    """Return the preprocessing function of this module matching ``name``.

    The lookup is case-insensitive: ``name`` is lower-cased and resolved as
    an attribute of the current module.

    Args:
        name (str): preprocessor name, e.g. the value of the ``dataset``
            config entry.

    Returns:
        callable: the module-level callable named ``name.lower()``.

    Raises:
        AttributeError: if this module has no attribute with that name, or
            if the attribute exists but is not callable (for example an
            imported module such as ``os`` or ``re``).
    """
    thismodule = sys.modules[__name__]
    key = name.lower()
    preprocessor = getattr(thismodule, key, None)
    # Module attributes also include imports and constants; only callables
    # can act as preprocessors, so anything else is reported as missing.
    if not callable(preprocessor):
        raise AttributeError(
            "No preprocessor named '{}' found in module '{}'".format(
                key, __name__))
    return preprocessor
def tweb(root_path, meta_file): def tweb(root_path, meta_file):

View File

@ -1,8 +1,5 @@
import argparse import argparse
import importlib
import json
import os import os
import shutil
import sys import sys
import time import time
import traceback import traceback
@ -27,10 +24,11 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
set_init_dict, copy_config_file, setup_model) set_init_dict, copy_config_file, setup_model)
from utils.logger import Logger from utils.logger import Logger
from utils.speakers import load_speaker_mapping, save_speaker_mapping, \ from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
copy_speaker_mapping get_speakers
from utils.synthesis import synthesis from utils.synthesis import synthesis
from utils.text.symbols import phonemes, symbols from utils.text.symbols import phonemes, symbols
from utils.visual import plot_alignment, plot_spectrogram from utils.visual import plot_alignment, plot_spectrogram
from datasets.preprocess import get_preprocessor_by_name
torch.backends.cudnn.enabled = True torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
@ -51,7 +49,7 @@ def setup_loader(is_val=False, verbose=False):
c.meta_file_val if is_val else c.meta_file_train, c.meta_file_val if is_val else c.meta_file_train,
c.r, c.r,
c.text_cleaner, c.text_cleaner,
preprocessor=preprocessor, preprocessor=get_preprocessor_by_name(c.dataset),
ap=ap, ap=ap,
batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
min_seq_len=0 if is_val else c.min_seq_len, min_seq_len=0 if is_val else c.min_seq_len,
@ -78,7 +76,7 @@ def setup_loader(is_val=False, verbose=False):
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
ap, epoch): ap, epoch):
data_loader = setup_loader(is_val=False, verbose=(epoch==0)) data_loader = setup_loader(is_val=False, verbose=(epoch==0))
if c.num_speakers > 1: if c.use_speaker_embedding:
speaker_mapping = load_speaker_mapping(OUT_PATH) speaker_mapping = load_speaker_mapping(OUT_PATH)
model.train() model.train()
epoch_time = 0 epoch_time = 0
@ -102,21 +100,10 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
avg_text_length = torch.mean(text_lengths.float()) avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float()) avg_spec_length = torch.mean(mel_lengths.float())
if c.num_speakers > 1: if c.use_speaker_embedding:
speaker_ids = [] speaker_ids = [speaker_mapping[speaker_name]
for speaker_name in speaker_names: for speaker_name in speaker_names]
if speaker_name not in speaker_mapping:
speaker_mapping[speaker_name] = len(speaker_mapping)
speaker_ids.append(speaker_mapping[speaker_name])
speaker_ids = torch.LongTensor(speaker_ids) speaker_ids = torch.LongTensor(speaker_ids)
if len(speaker_mapping) > c.num_speakers:
raise ValueError("It seems there are at least {} speakers in "
"your dataset, while 'num_speakers' is set to "
"{}. Found the following speakers: {}".format(
len(speaker_mapping),
c.num_speakers,
list(speaker_mapping)))
else: else:
speaker_ids = None speaker_ids = None
@ -271,15 +258,12 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
if c.tb_model_param_stats: if c.tb_model_param_stats:
tb_logger.tb_model_weights(model, current_step) tb_logger.tb_model_weights(model, current_step)
# save speaker mapping
if c.num_speakers > 1:
save_speaker_mapping(OUT_PATH, speaker_mapping)
return avg_postnet_loss, current_step return avg_postnet_loss, current_step
def evaluate(model, criterion, criterion_st, ap, current_step, epoch): def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
data_loader = setup_loader(is_val=True) data_loader = setup_loader(is_val=True)
if c.num_speakers > 1: if c.use_speaker_embedding:
speaker_mapping = load_speaker_mapping(OUT_PATH) speaker_mapping = load_speaker_mapping(OUT_PATH)
model.eval() model.eval()
epoch_time = 0 epoch_time = 0
@ -311,7 +295,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
mel_lengths = data[5] mel_lengths = data[5]
stop_targets = data[6] stop_targets = data[6]
if c.num_speakers > 1: if c.use_speaker_embedding:
speaker_ids = [speaker_mapping[speaker_name] speaker_ids = [speaker_mapping[speaker_name]
for speaker_name in speaker_names] for speaker_name in speaker_names]
speaker_ids = torch.LongTensor(speaker_ids) speaker_ids = torch.LongTensor(speaker_ids)
@ -443,7 +427,27 @@ def main(args):
init_distributed(args.rank, num_gpus, args.group_id, init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"]) c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols) num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, c)
if c.use_speaker_embedding:
speakers = get_speakers(c.data_path, c.meta_file_train, c.dataset)
if args.restore_path:
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
assert all([speaker in speaker_mapping
for speaker in speakers]), "As of now you, you cannot " \
"introduce new speakers to " \
"a previously trained model."
else:
speaker_mapping = {name: i
for i, name in enumerate(speakers)}
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
print("Training with {} speakers: {}".format(num_speakers,
", ".join(speakers)))
else:
num_speakers = 0
model = setup_model(num_chars, num_speakers, c)
print(" | > Num output units : {}".format(ap.num_freq), flush=True) print(" | > Num output units : {}".format(ap.num_freq), flush=True)
@ -482,10 +486,6 @@ def main(args):
" > Model restored from step %d" % checkpoint['step'], flush=True) " > Model restored from step %d" % checkpoint['step'], flush=True)
start_epoch = checkpoint['epoch'] start_epoch = checkpoint['epoch']
args.restore_step = checkpoint['step'] args.restore_step = checkpoint['step']
# copying speakers.json
prev_out_path = os.path.dirname(args.restore_path)
if c.num_speakers > 1:
copy_speaker_mapping(prev_out_path, OUT_PATH)
else: else:
args.restore_step = 0 args.restore_step = 0
@ -607,10 +607,6 @@ if __name__ == '__main__':
LOG_DIR = OUT_PATH LOG_DIR = OUT_PATH
tb_logger = Logger(LOG_DIR) tb_logger = Logger(LOG_DIR)
# Conditional imports
preprocessor = importlib.import_module('datasets.preprocess')
preprocessor = getattr(preprocessor, c.dataset.lower())
# Audio processor # Audio processor
ap = AudioProcessor(**c.audio) ap = AudioProcessor(**c.audio)

View File

@ -251,14 +251,14 @@ def set_init_dict(model_dict, checkpoint, c):
return model_dict return model_dict
def setup_model(num_chars, c): def setup_model(num_chars, num_speakers, c):
print(" > Using model: {}".format(c.model)) print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module('models.' + c.model.lower()) MyModel = importlib.import_module('models.' + c.model.lower())
MyModel = getattr(MyModel, c.model) MyModel = getattr(MyModel, c.model)
if c.model.lower() in ["tacotron", "tacotrongst"]: if c.model.lower() in ["tacotron", "tacotrongst"]:
model = MyModel( model = MyModel(
num_chars=num_chars, num_chars=num_chars,
num_speakers=c.num_speakers, num_speakers=num_speakers,
r=c.r, r=c.r,
linear_dim=1025, linear_dim=1025,
mel_dim=80, mel_dim=80,

View File

@ -1,6 +1,8 @@
import os import os
import json import json
from datasets.preprocess import get_preprocessor_by_name
def make_speakers_json_path(out_path): def make_speakers_json_path(out_path):
"""Returns conventional speakers.json location.""" """Returns conventional speakers.json location."""
@ -23,8 +25,9 @@ def save_speaker_mapping(out_path, speaker_mapping):
json.dump(speaker_mapping, f, indent=4) json.dump(speaker_mapping, f, indent=4)
def copy_speaker_mapping(out_path_a, out_path_b): def get_speakers(data_root, meta_file, dataset_type):
"""Copies a speaker mapping when restoring a model from a previous path.""" """Returns a sorted, unique list of speakers in a given dataset."""
speaker_mapping = load_speaker_mapping(out_path_a) preprocessor = get_preprocessor_by_name(dataset_type)
if speaker_mapping is not None: items = preprocessor(data_root, meta_file)
save_speaker_mapping(out_path_b, speaker_mapping) speakers = {e[2] for e in items}
return sorted(speakers)