multispeaker

pull/10/head
Thomas Werkmeister 2019-06-26 12:59:14 +02:00
parent 118fe61028
commit d172a3d3d5
14 changed files with 231 additions and 81 deletions

1
.dockerignore Normal file
View File

@ -0,0 +1 @@
.git/

View File

@ -1,23 +1,17 @@
FROM nvidia/cuda:9.0-base-ubuntu16.04 as base FROM pytorch/pytorch:1.0.1-cuda10.0-cudnn7-runtime
WORKDIR /srv/app WORKDIR /srv/app
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y git software-properties-common wget vim build-essential libsndfile1 && \ apt-get install -y libsndfile1 espeak && \
add-apt-repository ppa:deadsnakes/ppa && \ apt-get clean && \
apt-get update && \ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
apt-get install -y python3.6 python3.6-dev python3.6-tk && \
# Install pip manually
wget https://bootstrap.pypa.io/get-pip.py && \
python3.6 get-pip.py && \
rm get-pip.py && \
# Used by the server in server/synthesizer.py
pip install soundfile
ADD . /srv/app # Copy Source later to enable dependency caching
COPY requirements.txt /srv/app/
RUN pip install -r requirements.txt
# Setup for development COPY . /srv/app
RUN python3.6 setup.py develop
# http://bugs.python.org/issue19846 # http://bugs.python.org/issue19846
# > At the moment, setting "LANG=C" on a Linux system *fundamentally breaks Python 3*, and that's not OK. # > At the moment, setting "LANG=C" on a Linux system *fundamentally breaks Python 3*, and that's not OK.

View File

@ -37,6 +37,8 @@ class MyDataset(Dataset):
ap (TTS.utils.AudioProcessor): audio processor object. ap (TTS.utils.AudioProcessor): audio processor object.
preprocessor (dataset.preprocess.Class): preprocessor for the dataset. preprocessor (dataset.preprocess.Class): preprocessor for the dataset.
Create your own if you need to run a new dataset. Create your own if you need to run a new dataset.
speaker_id_cache_path (str): path where the speaker name to id
mapping is stored
batch_group_size (int): (0) range of batch randomization after sorting batch_group_size (int): (0) range of batch randomization after sorting
sequences by length. sequences by length.
min_seq_len (int): (0) minimum sequence length to be processed min_seq_len (int): (0) minimum sequence length to be processed
@ -105,7 +107,7 @@ class MyDataset(Dataset):
return text return text
def load_data(self, idx): def load_data(self, idx):
text, wav_file = self.items[idx] text, wav_file, speaker_name = self.items[idx]
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
if self.use_phonemes: if self.use_phonemes:
@ -120,7 +122,8 @@ class MyDataset(Dataset):
sample = { sample = {
'text': text, 'text': text,
'wav': wav, 'wav': wav,
'item_idx': self.items[idx][1] 'item_idx': self.items[idx][1],
'speaker_name': speaker_name
} }
return sample return sample
@ -182,6 +185,8 @@ class MyDataset(Dataset):
batch[idx]['item_idx'] for idx in ids_sorted_decreasing batch[idx]['item_idx'] for idx in ids_sorted_decreasing
] ]
text = [batch[idx]['text'] for idx in ids_sorted_decreasing] text = [batch[idx]['text'] for idx in ids_sorted_decreasing]
speaker_name = [batch[idx]['speaker_name']
for idx in ids_sorted_decreasing]
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
linear = [self.ap.spectrogram(w).astype('float32') for w in wav] linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
@ -219,7 +224,8 @@ class MyDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths) mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets) stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
stop_targets, item_idxs
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}".format(type(batch[0])))) found {}".format(type(batch[0]))))

View File

@ -1,5 +1,6 @@
import os import os
from glob import glob from glob import glob
import re
def tweb(root_path, meta_file): def tweb(root_path, meta_file):
@ -8,12 +9,13 @@ def tweb(root_path, meta_file):
""" """
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
speaker_name = "tweb"
with open(txt_file, 'r') as ttf: with open(txt_file, 'r') as ttf:
for line in ttf: for line in ttf:
cols = line.split('\t') cols = line.split('\t')
wav_file = os.path.join(root_path, cols[0] + '.wav') wav_file = os.path.join(root_path, cols[0] + '.wav')
text = cols[1] text = cols[1]
items.append([text, wav_file]) items.append([text, wav_file, speaker_name])
return items return items
@ -34,6 +36,7 @@ def mozilla_old(root_path, meta_file):
"""Normalizes Mozilla meta data files to TTS format""" """Normalizes Mozilla meta data files to TTS format"""
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
speaker_name = "mozilla_old"
with open(txt_file, 'r') as ttf: with open(txt_file, 'r') as ttf:
for line in ttf: for line in ttf:
cols = line.split('|') cols = line.split('|')
@ -41,7 +44,7 @@ def mozilla_old(root_path, meta_file):
wav_folder = "batch{}".format(batch_no) wav_folder = "batch{}".format(batch_no)
wav_file = os.path.join(root_path, wav_folder, "wavs_no_processing", cols[1].strip()) wav_file = os.path.join(root_path, wav_folder, "wavs_no_processing", cols[1].strip())
text = cols[0].strip() text = cols[0].strip()
items.append([text, wav_file]) items.append([text, wav_file, speaker_name])
return items return items
@ -49,27 +52,31 @@ def mozilla(root_path, meta_file):
"""Normalizes Mozilla meta data files to TTS format""" """Normalizes Mozilla meta data files to TTS format"""
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
speaker_name = "mozilla"
with open(txt_file, 'r') as ttf: with open(txt_file, 'r') as ttf:
for line in ttf: for line in ttf:
cols = line.split('|') cols = line.split('|')
wav_file = cols[1].strip() wav_file = cols[1].strip()
text = cols[0].strip() text = cols[0].strip()
wav_file = os.path.join(root_path, "wavs", wav_file) wav_file = os.path.join(root_path, "wavs", wav_file)
items.append([text, wav_file]) items.append([text, wav_file, speaker_name])
return items return items
def mailabs(root_path, meta_files): def mailabs(root_path, meta_files):
"""Normalizes M-AI-Labs meta data files to TTS format""" """Normalizes M-AI-Labs meta data files to TTS format"""
speaker_regex = re.compile("by_book/(male|female|mix)/(?P<speaker_name>[^/]+)/")
if meta_files is None: if meta_files is None:
csv_files = glob(root_path+"/**/metadata.csv", recursive=True) csv_files = glob(root_path+"/**/metadata.csv", recursive=True)
folders = [os.path.dirname(f) for f in csv_files] folders = [os.path.dirname(f) for f in csv_files]
else: else:
csv_files = meta_files csv_files = meta_files
folders = [f.strip().split("by_book")[1][1:] for f in csv_file] folders = [f.strip().split("by_book")[1][1:] for f in csv_files]
# meta_files = [f.strip() for f in meta_files.split(",")] # meta_files = [f.strip() for f in meta_files.split(",")]
items = [] items = []
for idx, csv_file in enumerate(csv_files): for idx, csv_file in enumerate(csv_files):
# determine speaker based on folder structure...
speaker_name = speaker_regex.search(csv_file).group("speaker_name")
print(" | > {}".format(csv_file)) print(" | > {}".format(csv_file))
folder = folders[idx] folder = folders[idx]
txt_file = os.path.join(root_path, csv_file) txt_file = os.path.join(root_path, csv_file)
@ -82,7 +89,7 @@ def mailabs(root_path, meta_files):
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), 'wavs', cols[0] + '.wav') wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), 'wavs', cols[0] + '.wav')
if os.path.isfile(wav_file): if os.path.isfile(wav_file):
text = cols[1].strip() text = cols[1].strip()
items.append([text, wav_file]) items.append([text, wav_file, speaker_name])
else: else:
raise RuntimeError("> File %s is not exist!"%(wav_file)) raise RuntimeError("> File %s is not exist!"%(wav_file))
return items return items
@ -92,12 +99,13 @@ def ljspeech(root_path, meta_file):
"""Normalizes the Nancy meta data file to TTS format""" """Normalizes the Nancy meta data file to TTS format"""
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
speaker_name = "ljspeech"
with open(txt_file, 'r') as ttf: with open(txt_file, 'r') as ttf:
for line in ttf: for line in ttf:
cols = line.split('|') cols = line.split('|')
wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav') wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav')
text = cols[1] text = cols[1]
items.append([text, wav_file]) items.append([text, wav_file, speaker_name])
return items return items
@ -105,12 +113,13 @@ def nancy(root_path, meta_file):
"""Normalizes the Nancy meta data file to TTS format""" """Normalizes the Nancy meta data file to TTS format"""
txt_file = os.path.join(root_path, meta_file) txt_file = os.path.join(root_path, meta_file)
items = [] items = []
speaker_name = "nancy"
with open(txt_file, 'r') as ttf: with open(txt_file, 'r') as ttf:
for line in ttf: for line in ttf:
id = line.split()[1] id = line.split()[1]
text = line[line.find('"') + 1:line.rfind('"') - 1] text = line[line.find('"') + 1:line.rfind('"') - 1]
wav_file = os.path.join(root_path, "wavn", id + ".wav") wav_file = os.path.join(root_path, "wavn", id + ".wav")
items.append([text, wav_file]) items.append([text, wav_file, speaker_name])
return items return items
@ -124,6 +133,7 @@ def common_voice(root_path, meta_file):
continue continue
cols = line.split("\t") cols = line.split("\t")
text = cols[2] text = cols[2]
speaker_name = cols[0]
wav_file = os.path.join(root_path, "clips", cols[1] + ".wav") wav_file = os.path.join(root_path, "clips", cols[1] + ".wav")
items.append([text, wav_file]) items.append([text, wav_file, speaker_name])
return items return items

View File

@ -9,6 +9,7 @@ from utils.generic_utils import sequence_mask
class Tacotron(nn.Module): class Tacotron(nn.Module):
def __init__(self, def __init__(self,
num_chars, num_chars,
num_speakers,
r=5, r=5,
linear_dim=1025, linear_dim=1025,
mel_dim=80, mel_dim=80,
@ -28,6 +29,9 @@ class Tacotron(nn.Module):
self.linear_dim = linear_dim self.linear_dim = linear_dim
self.embedding = nn.Embedding(num_chars, 256) self.embedding = nn.Embedding(num_chars, 256)
self.embedding.weight.data.normal_(0, 0.3) self.embedding.weight.data.normal_(0, 0.3)
self.speaker_embedding = nn.Embedding(num_speakers,
256)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(256) self.encoder = Encoder(256)
self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win, self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
attn_norm, prenet_type, prenet_dropout, attn_norm, prenet_type, prenet_dropout,
@ -38,11 +42,18 @@ class Tacotron(nn.Module):
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim), nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
nn.Sigmoid()) nn.Sigmoid())
def forward(self, characters, text_lengths, mel_specs): def forward(self, characters, speaker_ids, text_lengths, mel_specs):
B = characters.size(0) B = characters.size(0)
mask = sequence_mask(text_lengths).to(characters.device) mask = sequence_mask(text_lengths).to(characters.device)
inputs = self.embedding(characters) inputs = self.embedding(characters)
encoder_outputs = self.encoder(inputs) encoder_outputs = self.encoder(inputs)
speaker_embeddings = self.speaker_embedding(speaker_ids)
speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
encoder_outputs += speaker_embeddings
mel_outputs, alignments, stop_tokens = self.decoder( mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, mask) encoder_outputs, mel_specs, mask)
mel_outputs = mel_outputs.view(B, -1, self.mel_dim) mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
@ -50,10 +61,17 @@ class Tacotron(nn.Module):
linear_outputs = self.last_linear(linear_outputs) linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments, stop_tokens return mel_outputs, linear_outputs, alignments, stop_tokens
def inference(self, characters): def inference(self, characters, speaker_ids):
B = characters.size(0) B = characters.size(0)
inputs = self.embedding(characters) inputs = self.embedding(characters)
encoder_outputs = self.encoder(inputs) encoder_outputs = self.encoder(inputs)
speaker_embeddings = self.speaker_embedding(speaker_ids)
speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
encoder_outputs += speaker_embeddings
mel_outputs, alignments, stop_tokens = self.decoder.inference( mel_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs) encoder_outputs)
mel_outputs = mel_outputs.view(B, -1, self.mel_dim) mel_outputs = mel_outputs.view(B, -1, self.mel_dim)

View File

@ -11,6 +11,7 @@ from utils.generic_utils import sequence_mask
class Tacotron2(nn.Module): class Tacotron2(nn.Module):
def __init__(self, def __init__(self,
num_chars, num_chars,
num_speakers,
r, r,
attn_win=False, attn_win=False,
attn_norm="softmax", attn_norm="softmax",
@ -28,6 +29,8 @@ class Tacotron2(nn.Module):
std = sqrt(2.0 / (num_chars + 512)) std = sqrt(2.0 / (num_chars + 512))
val = sqrt(3.0) * std # uniform bounds for std val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val) self.embedding.weight.data.uniform_(-val, val)
self.speaker_embedding = nn.Embedding(num_speakers, 512)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(512) self.encoder = Encoder(512)
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
attn_norm, prenet_type, prenet_dropout, attn_norm, prenet_type, prenet_dropout,
@ -40,11 +43,19 @@ class Tacotron2(nn.Module):
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
return mel_outputs, mel_outputs_postnet, alignments return mel_outputs, mel_outputs_postnet, alignments
def forward(self, text, text_lengths, mel_specs=None): def forward(self, text, speaker_ids, text_lengths, mel_specs=None):
# compute mask for padding # compute mask for padding
mask = sequence_mask(text_lengths).to(text.device) mask = sequence_mask(text_lengths).to(text.device)
embedded_inputs = self.embedding(text).transpose(1, 2) embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, text_lengths) encoder_outputs = self.encoder(embedded_inputs, text_lengths)
speaker_embeddings = self.speaker_embedding(speaker_ids)
speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
encoder_outputs += speaker_embeddings
mel_outputs, stop_tokens, alignments = self.decoder( mel_outputs, stop_tokens, alignments = self.decoder(
encoder_outputs, mel_specs, mask) encoder_outputs, mel_specs, mask)
mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = self.postnet(mel_outputs)
@ -53,9 +64,16 @@ class Tacotron2(nn.Module):
mel_outputs, mel_outputs_postnet, alignments) mel_outputs, mel_outputs_postnet, alignments)
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
def inference(self, text): def inference(self, text, speaker_ids):
embedded_inputs = self.embedding(text).transpose(1, 2) embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs) encoder_outputs = self.encoder.inference(embedded_inputs)
speaker_embeddings = self.speaker_embedding(speaker_ids)
speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
encoder_outputs += speaker_embeddings
mel_outputs, stop_tokens, alignments = self.decoder.inference( mel_outputs, stop_tokens, alignments = self.decoder.inference(
encoder_outputs) encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = self.postnet(mel_outputs)
@ -64,12 +82,19 @@ class Tacotron2(nn.Module):
mel_outputs, mel_outputs_postnet, alignments) mel_outputs, mel_outputs_postnet, alignments)
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
def inference_truncated(self, text): def inference_truncated(self, text, speaker_ids):
""" """
Preserve model states for continuous inference Preserve model states for continuous inference
""" """
embedded_inputs = self.embedding(text).transpose(1, 2) embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference_truncated(embedded_inputs) encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
speaker_embeddings = self.speaker_embedding(speaker_ids)
speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
encoder_outputs += speaker_embeddings
mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated( mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
encoder_outputs) encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = self.postnet(mel_outputs)

View File

@ -10,6 +10,7 @@ from utils.generic_utils import sequence_mask
class TacotronGST(nn.Module): class TacotronGST(nn.Module):
def __init__(self, def __init__(self,
num_chars, num_chars,
num_speakers,
r=5, r=5,
linear_dim=1025, linear_dim=1025,
mel_dim=80, mel_dim=80,
@ -29,6 +30,8 @@ class TacotronGST(nn.Module):
self.linear_dim = linear_dim self.linear_dim = linear_dim
self.embedding = nn.Embedding(num_chars, 256) self.embedding = nn.Embedding(num_chars, 256)
self.embedding.weight.data.normal_(0, 0.3) self.embedding.weight.data.normal_(0, 0.3)
self.speaker_embedding = nn.Embedding(num_speakers, 256)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(256) self.encoder = Encoder(256)
self.gst = GST(num_mel=80, num_heads=4, num_style_tokens=10, embedding_dim=256) self.gst = GST(num_mel=80, num_heads=4, num_style_tokens=10, embedding_dim=256)
self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win, self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
@ -40,14 +43,22 @@ class TacotronGST(nn.Module):
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim), nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
nn.Sigmoid()) nn.Sigmoid())
def forward(self, characters, text_lengths, mel_specs): def forward(self, characters, speaker_ids, text_lengths, mel_specs):
B = characters.size(0) B = characters.size(0)
mask = sequence_mask(text_lengths).to(characters.device) mask = sequence_mask(text_lengths).to(characters.device)
inputs = self.embedding(characters) inputs = self.embedding(characters)
encoder_outputs = self.encoder(inputs) encoder_outputs = self.encoder(inputs)
speaker_embeddings = self.speaker_embedding(speaker_ids)
speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
gst_outputs = self.gst(mel_specs) gst_outputs = self.gst(mel_specs)
gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1) gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1)
encoder_outputs = encoder_outputs + gst_outputs encoder_outputs = encoder_outputs + gst_outputs + speaker_embeddings
mel_outputs, alignments, stop_tokens = self.decoder( mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, mask) encoder_outputs, mel_specs, mask)
mel_outputs = mel_outputs.view(B, -1, self.mel_dim) mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
@ -55,14 +66,21 @@ class TacotronGST(nn.Module):
linear_outputs = self.last_linear(linear_outputs) linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments, stop_tokens return mel_outputs, linear_outputs, alignments, stop_tokens
def inference(self, characters, style_mel=None): def inference(self, characters, speaker_ids, style_mel=None):
B = characters.size(0) B = characters.size(0)
inputs = self.embedding(characters) inputs = self.embedding(characters)
encoder_outputs = self.encoder(inputs) encoder_outputs = self.encoder(inputs)
speaker_embeddings = self.speaker_embedding(speaker_ids)
speaker_embeddings.unsqueeze_(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
encoder_outputs.size(1),
-1)
if style_mel is not None: if style_mel is not None:
gst_outputs = self.gst(style_mel) gst_outputs = self.gst(style_mel)
gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1) gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1)
encoder_outputs = encoder_outputs + gst_outputs encoder_outputs = encoder_outputs + gst_outputs
encoder_outputs += speaker_embeddings
mel_outputs, alignments, stop_tokens = self.decoder.inference( mel_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs) encoder_outputs)
mel_outputs = mel_outputs.view(B, -1, self.mel_dim) mel_outputs = mel_outputs.view(B, -1, self.mel_dim)

View File

@ -79,9 +79,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n", "def tts(model, text, speaker_id, CONFIG, use_cuda, ap, use_gl, figures=True):\n",
" t_1 = time.time()\n", " t_1 = time.time()\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, False, CONFIG.enable_eos_bos_chars)\n", " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, speaker_id, CONFIG, use_cuda, ap, False, CONFIG.enable_eos_bos_chars)\n",
" if CONFIG.model == \"Tacotron\" and not use_gl:\n", " if CONFIG.model == \"Tacotron\" and not use_gl:\n",
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
" if not use_gl:\n", " if not use_gl:\n",
@ -208,8 +208,9 @@
"source": [ "source": [
"model.eval()\n", "model.eval()\n",
"model.decoder.max_decoder_steps = 2000\n", "model.decoder.max_decoder_steps = 2000\n",
"speaker_id = 0\n",
"sentence = \"Bill got in the habit of asking himself “Is that thought true?” And if he wasnt absolutely certain it was, he just let it go.\"\n", "sentence = \"Bill got in the habit of asking himself “Is that thought true?” And if he wasnt absolutely certain it was, he just let it go.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -221,7 +222,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Be a voice, not an echo.\" # 'echo' is not in training set. \n", "sentence = \"Be a voice, not an echo.\" # 'echo' is not in training set. \n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -231,7 +232,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"The human voice is the most perfect instrument of all.\"\n", "sentence = \"The human voice is the most perfect instrument of all.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -241,7 +242,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n", "sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -253,7 +254,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"This cake is great. It's so delicious and moist.\"\n", "sentence = \"This cake is great. It's so delicious and moist.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -270,7 +271,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n", "sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -280,7 +281,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n", "sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -290,7 +291,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Heres a way to measure the acute emotional intelligence that has never gone out of style.\"\n", "sentence = \"Heres a way to measure the acute emotional intelligence that has never gone out of style.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -300,7 +301,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n", "sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -310,7 +311,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"The buses aren't the problem, they actually provide a solution.\"\n", "sentence = \"The buses aren't the problem, they actually provide a solution.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -327,7 +328,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n", "sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -337,7 +338,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n", "sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -347,7 +348,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \" He has read the whole thing.\"\n", "sentence = \" He has read the whole thing.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -357,7 +358,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"He reads books.\"\n", "sentence = \"He reads books.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -369,7 +370,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Thisss isrealy awhsome.\"\n", "sentence = \"Thisss isrealy awhsome.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -381,7 +382,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"This is your internet browser, Firefox.\"\n", "sentence = \"This is your internet browser, Firefox.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -391,7 +392,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"This is your internet browser Firefox.\"\n", "sentence = \"This is your internet browser Firefox.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -401,7 +402,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"The quick brown fox jumps over the lazy dog.\"\n", "sentence = \"The quick brown fox jumps over the lazy dog.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -411,7 +412,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Does the quick brown fox jump over the lazy dog?\"\n", "sentence = \"Does the quick brown fox jump over the lazy dog?\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -421,7 +422,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Eren, how are you?\"\n", "sentence = \"Eren, how are you?\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -438,7 +439,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Encouraged, he started with a minute a day.\"\n", "sentence = \"Encouraged, he started with a minute a day.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -448,7 +449,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n", "sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -458,7 +459,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . \"\n", "sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . \"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -468,7 +469,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"If he decided to watch TV he really watched it.\"\n", "sentence = \"If he decided to watch TV he really watched it.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -480,7 +481,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"sentence = \"Often we try to bring about change through sheer effort and we put all of our energy into a new initiative .\"\n", "sentence = \"Often we try to bring about change through sheer effort and we put all of our energy into a new initiative .\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {
@ -491,7 +492,7 @@
"source": [ "source": [
"# for twb dataset\n", "# for twb dataset\n",
"sentence = \"In our preparation for Easter, God in his providence offers us each year the season of Lent as a sacramental sign of our conversion.\"\n", "sentence = \"In our preparation for Easter, God in his providence offers us each year the season of Lent as a sacramental sign of our conversion.\"\n",
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
] ]
}, },
{ {

View File

@ -10,3 +10,4 @@ flask
scipy==0.19.0 scipy==0.19.0
tqdm tqdm
git+git://github.com/bootphon/phonemizer@master git+git://github.com/bootphon/phonemizer@master
soundfile

View File

@ -1,5 +1,6 @@
import argparse import argparse
import importlib import importlib
import json
import os import os
import shutil import shutil
import sys import sys
@ -25,6 +26,8 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
save_checkpoint, sequence_mask, weight_decay, save_checkpoint, sequence_mask, weight_decay,
set_init_dict, copy_config_file, setup_model) set_init_dict, copy_config_file, setup_model)
from utils.logger import Logger from utils.logger import Logger
from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
copy_speaker_mapping
from utils.synthesis import synthesis from utils.synthesis import synthesis
from utils.text.symbols import phonemes, symbols from utils.text.symbols import phonemes, symbols
from utils.visual import plot_alignment, plot_spectrogram from utils.visual import plot_alignment, plot_spectrogram
@ -75,6 +78,7 @@ def setup_loader(is_val=False, verbose=False):
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
ap, epoch): ap, epoch):
data_loader = setup_loader(is_val=False, verbose=(epoch==0)) data_loader = setup_loader(is_val=False, verbose=(epoch==0))
speaker_mapping = load_speaker_mapping(OUT_PATH)
model.train() model.train()
epoch_time = 0 epoch_time = 0
avg_postnet_loss = 0 avg_postnet_loss = 0
@ -89,13 +93,21 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
# setup input data # setup input data
text_input = data[0] text_input = data[0]
text_lengths = data[1] text_lengths = data[1]
linear_input = data[2] if c.model in ["Tacotron", "TacotronGST"] else None speaker_names = data[2]
mel_input = data[3] linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
mel_lengths = data[4] mel_input = data[4]
stop_targets = data[5] mel_lengths = data[5]
stop_targets = data[6]
avg_text_length = torch.mean(text_lengths.float()) avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float()) avg_spec_length = torch.mean(mel_lengths.float())
speaker_ids = []
for speaker_name in speaker_names:
if speaker_name not in speaker_mapping:
speaker_mapping[speaker_name] = len(speaker_mapping)
speaker_ids.append(speaker_mapping[speaker_name])
speaker_ids = torch.LongTensor(speaker_ids)
# set stop targets view, we predict a single stop token per r frames prediction # set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0], stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r, -1) stop_targets.size(1) // c.r, -1)
@ -118,10 +130,11 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
mel_lengths = mel_lengths.cuda(non_blocking=True) mel_lengths = mel_lengths.cuda(non_blocking=True)
linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
stop_targets = stop_targets.cuda(non_blocking=True) stop_targets = stop_targets.cuda(non_blocking=True)
speaker_ids = speaker_ids.cuda(non_blocking=True)
# forward pass model # forward pass model
decoder_output, postnet_output, alignments, stop_tokens = model( decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input) text_input, speaker_ids, text_lengths, mel_input)
# loss computation # loss computation
stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1) stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
@ -244,11 +257,15 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
tb_logger.tb_train_epoch_stats(current_step, epoch_stats) tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
if c.tb_model_param_stats: if c.tb_model_param_stats:
tb_logger.tb_model_weights(model, current_step) tb_logger.tb_model_weights(model, current_step)
# save speaker mapping
save_speaker_mapping(OUT_PATH, speaker_mapping)
return avg_postnet_loss, current_step return avg_postnet_loss, current_step
def evaluate(model, criterion, criterion_st, ap, current_step, epoch): def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
data_loader = setup_loader(is_val=True) data_loader = setup_loader(is_val=True)
speaker_mapping = load_speaker_mapping(OUT_PATH)
model.eval() model.eval()
epoch_time = 0 epoch_time = 0
avg_postnet_loss = 0 avg_postnet_loss = 0
@ -273,10 +290,15 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
# setup input data # setup input data
text_input = data[0] text_input = data[0]
text_lengths = data[1] text_lengths = data[1]
linear_input = data[2] if c.model in ["Tacotron", "TacotronGST"] else None speaker_names = data[2]
mel_input = data[3] linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
mel_lengths = data[4] mel_input = data[4]
stop_targets = data[5] mel_lengths = data[5]
stop_targets = data[6]
speaker_ids = [speaker_mapping[speaker_name]
for speaker_name in speaker_names]
speaker_ids = torch.LongTensor(speaker_ids)
# set stop targets view, we predict a single stop token per r frames prediction # set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0], stop_targets = stop_targets.view(text_input.shape[0],
@ -291,10 +313,12 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
mel_lengths = mel_lengths.cuda() mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda() if c.model in ["Tacotron", "TacotronGST"] else None linear_input = linear_input.cuda() if c.model in ["Tacotron", "TacotronGST"] else None
stop_targets = stop_targets.cuda() stop_targets = stop_targets.cuda()
speaker_ids = speaker_ids.cuda()
# forward pass # forward pass
decoder_output, postnet_output, alignments, stop_tokens =\ decoder_output, postnet_output, alignments, stop_tokens =\
model.forward(text_input, text_lengths, mel_input) model.forward(text_input, speaker_ids,
text_lengths, mel_input)
# loss computation # loss computation
stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1) stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
@ -372,10 +396,11 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
print(" | > Synthesizing test sentences") print(" | > Synthesizing test sentences")
speaker_id = 0
for idx, test_sentence in enumerate(test_sentences): for idx, test_sentence in enumerate(test_sentences):
try: try:
wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis( wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
model, test_sentence, c, use_cuda, ap) model, test_sentence, speaker_id, c, use_cuda, ap)
file_path = os.path.join(AUDIO_PATH, str(current_step)) file_path = os.path.join(AUDIO_PATH, str(current_step))
os.makedirs(file_path, exist_ok=True) os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path, file_path = os.path.join(file_path,
@ -437,6 +462,9 @@ def main(args):
" > Model restored from step %d" % checkpoint['step'], flush=True) " > Model restored from step %d" % checkpoint['step'], flush=True)
start_epoch = checkpoint['epoch'] start_epoch = checkpoint['epoch']
args.restore_step = checkpoint['step'] args.restore_step = checkpoint['step']
# copying speakers.json
prev_out_path = os.path.dirname(args.restore_path)
copy_speaker_mapping(prev_out_path, OUT_PATH)
else: else:
args.restore_step = 0 args.restore_step = 0

View File

@ -236,7 +236,10 @@ class AudioProcessor(object):
if self.do_trim_silence: if self.do_trim_silence:
x = self.trim_silence(x) x = self.trim_silence(x)
# sr, x = io.wavfile.read(filename) # sr, x = io.wavfile.read(filename)
assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr) assert self.sample_rate == sr, "Expected sampling rate {} but file " \
"{} has {}.".format(self.sample_rate,
filename,
sr)
return x return x
def encode_16bits(self, x): def encode_16bits(self, x):

View File

@ -33,9 +33,13 @@ def load_config(config_path):
def get_git_branch(): def get_git_branch():
try:
out = subprocess.check_output(["git", "branch"]).decode("utf8") out = subprocess.check_output(["git", "branch"]).decode("utf8")
current = next(line for line in out.split("\n") if line.startswith("*")) current = next(line for line in out.split("\n") if line.startswith("*"))
return current.replace("* ", "") current.replace("* ", "")
except subprocess.CalledProcessError:
current = "inside_docker"
return current
def get_commit_hash(): def get_commit_hash():
@ -46,8 +50,12 @@ def get_commit_hash():
# except: # except:
# raise RuntimeError( # raise RuntimeError(
# " !! Commit before training to get the commit hash.") # " !! Commit before training to get the commit hash.")
try:
commit = subprocess.check_output(['git', 'rev-parse', '--short', commit = subprocess.check_output(['git', 'rev-parse', '--short',
'HEAD']).decode().strip() 'HEAD']).decode().strip()
# Not copying .git folder into docker container
except subprocess.CalledProcessError:
commit = "0000000"
print(' > Git Hash: {}'.format(commit)) print(' > Git Hash: {}'.format(commit))
return commit return commit
@ -250,6 +258,7 @@ def setup_model(num_chars, c):
if c.model.lower() in ["tacotron", "tacotrongst"]: if c.model.lower() in ["tacotron", "tacotrongst"]:
model = MyModel( model = MyModel(
num_chars=num_chars, num_chars=num_chars,
num_speakers=c.num_speakers,
r=c.r, r=c.r,
linear_dim=1025, linear_dim=1025,
mel_dim=80, mel_dim=80,
@ -266,6 +275,7 @@ def setup_model(num_chars, c):
elif c.model.lower() == "tacotron2": elif c.model.lower() == "tacotron2":
model = MyModel( model = MyModel(
num_chars=num_chars, num_chars=num_chars,
num_speakers=c.num_speakers,
r=c.r, r=c.r,
attn_win=c.windowing, attn_win=c.windowing,
attn_norm=c.attention_norm, attn_norm=c.attention_norm,

30
utils/speakers.py Normal file
View File

@ -0,0 +1,30 @@
import os
import json
def make_speakers_json_path(out_path):
"""Returns conventional speakers.json location."""
return os.path.join(out_path, "speakers.json")
def load_speaker_mapping(out_path):
"""Loads speaker mapping if already present."""
try:
with open(make_speakers_json_path(out_path)) as f:
return json.load(f)
except FileNotFoundError:
return {}
def save_speaker_mapping(out_path, speaker_mapping):
"""Saves speaker mapping if not yet present."""
speakers_json_path = make_speakers_json_path(out_path)
with open(speakers_json_path, "w") as f:
json.dump(speaker_mapping, f, indent=4)
def copy_speaker_mapping(out_path_a, out_path_b):
"""Copies a speaker mapping when restoring a model from a previous path."""
speaker_mapping = load_speaker_mapping(out_path_a)
if speaker_mapping is not None:
save_speaker_mapping(out_path_b, speaker_mapping)

View File

@ -70,6 +70,7 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
def synthesis(model, def synthesis(model,
text, text,
speaker_id,
CONFIG, CONFIG,
use_cuda, use_cuda,
ap, ap,
@ -82,6 +83,7 @@ def synthesis(model,
Args: Args:
model (TTS.models): model to synthesize. model (TTS.models): model to synthesize.
text (str): target text text (str): target text
speaker_id (int): id of speaker
CONFIG (dict): config dictionary to be loaded from config.json. CONFIG (dict): config dictionary to be loaded from config.json.
use_cuda (bool): enable cuda. use_cuda (bool): enable cuda.
ap (TTS.utils.audio.AudioProcessor): audio processor to process ap (TTS.utils.audio.AudioProcessor): audio processor to process
@ -98,6 +100,9 @@ def synthesis(model,
style_mel = compute_style_mel(style_wav, ap, use_cuda) style_mel = compute_style_mel(style_wav, ap, use_cuda)
# preprocess the given text # preprocess the given text
inputs = text_to_seqvec(text, CONFIG, use_cuda) inputs = text_to_seqvec(text, CONFIG, use_cuda)
speaker_id = speaker_id_var = torch.from_numpy(speaker_id).unsqueeze(0)
if use_cuda:
speaker_id.cuda()
# synthesize voice # synthesize voice
decoder_output, postnet_output, alignments, stop_tokens = run_model( decoder_output, postnet_output, alignments, stop_tokens = run_model(
model, inputs, CONFIG, truncated, style_mel) model, inputs, CONFIG, truncated, style_mel)