mirror of https://github.com/coqui-ai/TTS.git
branch: multispeaker
parent: 118fe61028
commit: d172a3d3d5
@@ -0,0 +1 @@
+.git/
Dockerfile (22 lines changed)
@@ -1,23 +1,17 @@
-FROM nvidia/cuda:9.0-base-ubuntu16.04 as base
+FROM pytorch/pytorch:1.0.1-cuda10.0-cudnn7-runtime

 WORKDIR /srv/app

 RUN apt-get update && \
-    apt-get install -y git software-properties-common wget vim build-essential libsndfile1 && \
-    add-apt-repository ppa:deadsnakes/ppa && \
-    apt-get update && \
-    apt-get install -y python3.6 python3.6-dev python3.6-tk && \
-    # Install pip manually
-    wget https://bootstrap.pypa.io/get-pip.py && \
-    python3.6 get-pip.py && \
-    rm get-pip.py && \
-    # Used by the server in server/synthesizer.py
-    pip install soundfile
+    apt-get install -y libsndfile1 espeak && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*

-ADD . /srv/app
+# Copy Source later to enable dependency caching
+COPY requirements.txt /srv/app/
+RUN pip install -r requirements.txt

-# Setup for development
-RUN python3.6 setup.py develop
+COPY . /srv/app

 # http://bugs.python.org/issue19846
 # > At the moment, setting "LANG=C" on a Linux system *fundamentally breaks Python 3*, and that's not OK.
@@ -37,6 +37,8 @@ class MyDataset(Dataset):
             ap (TTS.utils.AudioProcessor): audio processor object.
             preprocessor (dataset.preprocess.Class): preprocessor for the dataset.
                 Create your own if you need to run a new dataset.
+            speaker_id_cache_path (str): path where the speaker name to id
+                mapping is stored
             batch_group_size (int): (0) range of batch randomization after sorting
                 sequences by length.
             min_seq_len (int): (0) minimum sequence length to be processed
@@ -105,7 +107,7 @@ class MyDataset(Dataset):
         return text

     def load_data(self, idx):
-        text, wav_file = self.items[idx]
+        text, wav_file, speaker_name = self.items[idx]
         wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)

         if self.use_phonemes:
@@ -120,7 +122,8 @@ class MyDataset(Dataset):
         sample = {
             'text': text,
             'wav': wav,
-            'item_idx': self.items[idx][1]
+            'item_idx': self.items[idx][1],
+            'speaker_name': speaker_name
         }
         return sample

@@ -182,6 +185,8 @@ class MyDataset(Dataset):
                 batch[idx]['item_idx'] for idx in ids_sorted_decreasing
             ]
            text = [batch[idx]['text'] for idx in ids_sorted_decreasing]
+            speaker_name = [batch[idx]['speaker_name']
+                            for idx in ids_sorted_decreasing]

            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
@@ -219,7 +224,8 @@ class MyDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)

-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs
+            return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
+                stop_targets, item_idxs

         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}".format(type(batch[0]))))
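Note: after this change collate_fn returns an 8-tuple with the speaker names at index 2, shifting the spectrogram fields one slot to the right. A minimal sketch of unpacking it (variable names are illustrative, not from the commit):

for data in data_loader:
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]   # new field introduced by this commit
    linear_input = data[3]
    mel_input = data[4]
    mel_lengths = data[5]
    stop_targets = data[6]
    item_idxs = data[7]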
@@ -1,5 +1,6 @@
 import os
 from glob import glob
+import re


 def tweb(root_path, meta_file):
@@ -8,12 +9,13 @@ def tweb(root_path, meta_file):
     """
     txt_file = os.path.join(root_path, meta_file)
     items = []
+    speaker_name = "tweb"
     with open(txt_file, 'r') as ttf:
         for line in ttf:
             cols = line.split('\t')
             wav_file = os.path.join(root_path, cols[0] + '.wav')
             text = cols[1]
-            items.append([text, wav_file])
+            items.append([text, wav_file, speaker_name])
     return items


@@ -34,6 +36,7 @@ def mozilla_old(root_path, meta_file):
     """Normalizes Mozilla meta data files to TTS format"""
     txt_file = os.path.join(root_path, meta_file)
     items = []
+    speaker_name = "mozilla_old"
     with open(txt_file, 'r') as ttf:
         for line in ttf:
             cols = line.split('|')
@@ -41,7 +44,7 @@ def mozilla_old(root_path, meta_file):
             wav_folder = "batch{}".format(batch_no)
             wav_file = os.path.join(root_path, wav_folder, "wavs_no_processing", cols[1].strip())
             text = cols[0].strip()
-            items.append([text, wav_file])
+            items.append([text, wav_file, speaker_name])
     return items


@@ -49,27 +52,31 @@ def mozilla(root_path, meta_file):
     """Normalizes Mozilla meta data files to TTS format"""
     txt_file = os.path.join(root_path, meta_file)
     items = []
+    speaker_name = "mozilla"
     with open(txt_file, 'r') as ttf:
         for line in ttf:
             cols = line.split('|')
             wav_file = cols[1].strip()
             text = cols[0].strip()
             wav_file = os.path.join(root_path, "wavs", wav_file)
-            items.append([text, wav_file])
+            items.append([text, wav_file, speaker_name])
     return items


 def mailabs(root_path, meta_files):
     """Normalizes M-AI-Labs meta data files to TTS format"""
+    speaker_regex = re.compile("by_book/(male|female|mix)/(?P<speaker_name>[^/]+)/")
     if meta_files is None:
         csv_files = glob(root_path+"/**/metadata.csv", recursive=True)
         folders = [os.path.dirname(f) for f in csv_files]
     else:
         csv_files = meta_files
-        folders = [f.strip().split("by_book")[1][1:] for f in csv_file]
+        folders = [f.strip().split("by_book")[1][1:] for f in csv_files]
     # meta_files = [f.strip() for f in meta_files.split(",")]
     items = []
     for idx, csv_file in enumerate(csv_files):
+        # determine speaker based on folder structure...
+        speaker_name = speaker_regex.search(csv_file).group("speaker_name")
         print(" | > {}".format(csv_file))
         folder = folders[idx]
         txt_file = os.path.join(root_path, csv_file)
@@ -82,7 +89,7 @@ def mailabs(root_path, meta_files):
                 wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), 'wavs', cols[0] + '.wav')
                 if os.path.isfile(wav_file):
                     text = cols[1].strip()
-                    items.append([text, wav_file])
+                    items.append([text, wav_file, speaker_name])
                 else:
                     raise RuntimeError("> File %s is not exist!"%(wav_file))
     return items
@@ -92,12 +99,13 @@ def ljspeech(root_path, meta_file):
     """Normalizes the Nancy meta data file to TTS format"""
     txt_file = os.path.join(root_path, meta_file)
     items = []
+    speaker_name = "ljspeech"
     with open(txt_file, 'r') as ttf:
         for line in ttf:
             cols = line.split('|')
             wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav')
             text = cols[1]
-            items.append([text, wav_file])
+            items.append([text, wav_file, speaker_name])
     return items


@@ -105,12 +113,13 @@ def nancy(root_path, meta_file):
     """Normalizes the Nancy meta data file to TTS format"""
     txt_file = os.path.join(root_path, meta_file)
     items = []
+    speaker_name = "nancy"
     with open(txt_file, 'r') as ttf:
         for line in ttf:
             id = line.split()[1]
             text = line[line.find('"') + 1:line.rfind('"') - 1]
             wav_file = os.path.join(root_path, "wavn", id + ".wav")
-            items.append([text, wav_file])
+            items.append([text, wav_file, speaker_name])
     return items


@@ -124,6 +133,7 @@ def common_voice(root_path, meta_file):
             continue
         cols = line.split("\t")
         text = cols[2]
+        speaker_name = cols[0]
         wav_file = os.path.join(root_path, "clips", cols[1] + ".wav")
-        items.append([text, wav_file])
+        items.append([text, wav_file, speaker_name])
     return items
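Note: a quick sketch of what the new M-AI-Labs speaker regex extracts from a metadata path; the path below is made up for illustration:

import re

speaker_regex = re.compile("by_book/(male|female|mix)/(?P<speaker_name>[^/]+)/")
csv_file = "en_US/by_book/female/judy_bieber/ozma_of_oz/metadata.csv"  # hypothetical path
print(speaker_regex.search(csv_file).group("speaker_name"))  # -> judy_bieber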
@@ -9,6 +9,7 @@ from utils.generic_utils import sequence_mask
 class Tacotron(nn.Module):
     def __init__(self,
                  num_chars,
+                 num_speakers,
                  r=5,
                  linear_dim=1025,
                  mel_dim=80,
@@ -28,6 +29,9 @@ class Tacotron(nn.Module):
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(num_chars, 256)
         self.embedding.weight.data.normal_(0, 0.3)
+        self.speaker_embedding = nn.Embedding(num_speakers,
+                                              256)
+        self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
         self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
@@ -38,11 +42,18 @@ class Tacotron(nn.Module):
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
             nn.Sigmoid())

-    def forward(self, characters, text_lengths, mel_specs):
+    def forward(self, characters, speaker_ids, text_lengths, mel_specs):
         B = characters.size(0)
         mask = sequence_mask(text_lengths).to(characters.device)
         inputs = self.embedding(characters)
         encoder_outputs = self.encoder(inputs)
+        speaker_embeddings = self.speaker_embedding(speaker_ids)
+
+        speaker_embeddings.unsqueeze_(1)
+        speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
+                                                       encoder_outputs.size(1),
+                                                       -1)
+        encoder_outputs += speaker_embeddings
         mel_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask)
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
@@ -50,10 +61,17 @@ class Tacotron(nn.Module):
         linear_outputs = self.last_linear(linear_outputs)
         return mel_outputs, linear_outputs, alignments, stop_tokens

-    def inference(self, characters):
+    def inference(self, characters, speaker_ids):
         B = characters.size(0)
         inputs = self.embedding(characters)
         encoder_outputs = self.encoder(inputs)
+        speaker_embeddings = self.speaker_embedding(speaker_ids)
+
+        speaker_embeddings.unsqueeze_(1)
+        speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
+                                                       encoder_outputs.size(1),
+                                                       -1)
+        encoder_outputs += speaker_embeddings
         mel_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
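Note: the unsqueeze_/expand sequence above broadcasts one speaker vector per utterance across every encoder time step before adding it to the encoder outputs. A self-contained sketch of the same pattern, with illustrative dimensions:

import torch
import torch.nn as nn

B, T, D = 2, 5, 256                      # batch, time steps, channels
speaker_embedding = nn.Embedding(10, D)  # e.g. 10 speakers

encoder_outputs = torch.randn(B, T, D)
speaker_ids = torch.LongTensor([0, 3])   # one id per utterance

spk = speaker_embedding(speaker_ids)     # (B, D)
spk = spk.unsqueeze(1)                   # (B, 1, D)
spk = spk.expand(B, T, -1)               # (B, T, D) view, no copy
encoder_outputs = encoder_outputs + spk  # same vector at every time step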
@@ -11,6 +11,7 @@ from utils.generic_utils import sequence_mask
 class Tacotron2(nn.Module):
     def __init__(self,
                  num_chars,
+                 num_speakers,
                  r,
                  attn_win=False,
                  attn_norm="softmax",
@@ -28,6 +29,8 @@ class Tacotron2(nn.Module):
         std = sqrt(2.0 / (num_chars + 512))
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
+        self.speaker_embedding = nn.Embedding(num_speakers, 512)
+        self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(512)
         self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
@@ -40,11 +43,19 @@ class Tacotron2(nn.Module):
         mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
         return mel_outputs, mel_outputs_postnet, alignments

-    def forward(self, text, text_lengths, mel_specs=None):
+    def forward(self, text, speaker_ids, text_lengths, mel_specs=None):
         # compute mask for padding
         mask = sequence_mask(text_lengths).to(text.device)
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
+        speaker_embeddings = self.speaker_embedding(speaker_ids)
+
+        speaker_embeddings.unsqueeze_(1)
+        speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
+                                                       encoder_outputs.size(1),
+                                                       -1)
+
+        encoder_outputs += speaker_embeddings
         mel_outputs, stop_tokens, alignments = self.decoder(
             encoder_outputs, mel_specs, mask)
         mel_outputs_postnet = self.postnet(mel_outputs)
@@ -53,9 +64,16 @@ class Tacotron2(nn.Module):
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

-    def inference(self, text):
+    def inference(self, text, speaker_ids):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
+        speaker_embeddings = self.speaker_embedding(speaker_ids)
+
+        speaker_embeddings.unsqueeze_(1)
+        speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
+                                                       encoder_outputs.size(1),
+                                                       -1)
+        encoder_outputs += speaker_embeddings
         mel_outputs, stop_tokens, alignments = self.decoder.inference(
             encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
@@ -64,12 +82,19 @@ class Tacotron2(nn.Module):
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

-    def inference_truncated(self, text):
+    def inference_truncated(self, text, speaker_ids):
         """
         Preserve model states for continuous inference
         """
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
+        speaker_embeddings = self.speaker_embedding(speaker_ids)
+
+        speaker_embeddings.unsqueeze_(1)
+        speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
+                                                       encoder_outputs.size(1),
+                                                       -1)
+        encoder_outputs += speaker_embeddings
         mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
             encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
@@ -10,6 +10,7 @@ from utils.generic_utils import sequence_mask
 class TacotronGST(nn.Module):
     def __init__(self,
                  num_chars,
+                 num_speakers,
                  r=5,
                  linear_dim=1025,
                  mel_dim=80,
@@ -29,6 +30,8 @@ class TacotronGST(nn.Module):
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(num_chars, 256)
         self.embedding.weight.data.normal_(0, 0.3)
+        self.speaker_embedding = nn.Embedding(num_speakers, 256)
+        self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
         self.gst = GST(num_mel=80, num_heads=4, num_style_tokens=10, embedding_dim=256)
         self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
@@ -40,14 +43,22 @@ class TacotronGST(nn.Module):
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
             nn.Sigmoid())

-    def forward(self, characters, text_lengths, mel_specs):
+    def forward(self, characters, speaker_ids, text_lengths, mel_specs):
         B = characters.size(0)
         mask = sequence_mask(text_lengths).to(characters.device)
         inputs = self.embedding(characters)
         encoder_outputs = self.encoder(inputs)
+
+        speaker_embeddings = self.speaker_embedding(speaker_ids)
+
+        speaker_embeddings.unsqueeze_(1)
+        speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
+                                                       encoder_outputs.size(1),
+                                                       -1)
+
         gst_outputs = self.gst(mel_specs)
         gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1)
-        encoder_outputs = encoder_outputs + gst_outputs
+        encoder_outputs = encoder_outputs + gst_outputs + speaker_embeddings
         mel_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask)
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
@@ -55,14 +66,21 @@ class TacotronGST(nn.Module):
         linear_outputs = self.last_linear(linear_outputs)
         return mel_outputs, linear_outputs, alignments, stop_tokens

-    def inference(self, characters, style_mel=None):
+    def inference(self, characters, speaker_ids, style_mel=None):
         B = characters.size(0)
         inputs = self.embedding(characters)
         encoder_outputs = self.encoder(inputs)
+        speaker_embeddings = self.speaker_embedding(speaker_ids)
+
+        speaker_embeddings.unsqueeze_(1)
+        speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
+                                                       encoder_outputs.size(1),
+                                                       -1)
         if style_mel is not None:
             gst_outputs = self.gst(style_mel)
             gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1)
             encoder_outputs = encoder_outputs + gst_outputs
+        encoder_outputs += speaker_embeddings
         mel_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
@@ -79,9 +79,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n",
+    "def tts(model, text, speaker_id, CONFIG, use_cuda, ap, use_gl, figures=True):\n",
     "    t_1 = time.time()\n",
-    "    waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, False, CONFIG.enable_eos_bos_chars)\n",
+    "    waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, speaker_id, CONFIG, use_cuda, ap, False, CONFIG.enable_eos_bos_chars)\n",
     "    if CONFIG.model == \"Tacotron\" and not use_gl:\n",
     "        mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
     "    if not use_gl:\n",
@@ -208,8 +208,9 @@
    "source": [
     "model.eval()\n",
     "model.decoder.max_decoder_steps = 2000\n",
+    "speaker_id = 0\n",
     "sentence = \"Bill got in the habit of asking himself “Is that thought true?” And if he wasn’t absolutely certain it was, he just let it go.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -221,7 +222,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Be a voice, not an echo.\" # 'echo' is not in training set. \n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -231,7 +232,7 @@
    "outputs": [],
    "source": [
     "sentence = \"The human voice is the most perfect instrument of all.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -241,7 +242,7 @@
    "outputs": [],
    "source": [
     "sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -253,7 +254,7 @@
    "outputs": [],
    "source": [
     "sentence = \"This cake is great. It's so delicious and moist.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -270,7 +271,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -280,7 +281,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -290,7 +291,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Here’s a way to measure the acute emotional intelligence that has never gone out of style.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -300,7 +301,7 @@
    "outputs": [],
    "source": [
     "sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -310,7 +311,7 @@
    "outputs": [],
    "source": [
     "sentence = \"The buses aren't the problem, they actually provide a solution.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -327,7 +328,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -337,7 +338,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -347,7 +348,7 @@
    "outputs": [],
    "source": [
     "sentence = \" He has read the whole thing.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -357,7 +358,7 @@
    "outputs": [],
    "source": [
     "sentence = \"He reads books.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -369,7 +370,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Thisss isrealy awhsome.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -381,7 +382,7 @@
    "outputs": [],
    "source": [
     "sentence = \"This is your internet browser, Firefox.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -391,7 +392,7 @@
    "outputs": [],
    "source": [
     "sentence = \"This is your internet browser Firefox.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -401,7 +402,7 @@
    "outputs": [],
    "source": [
     "sentence = \"The quick brown fox jumps over the lazy dog.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -411,7 +412,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Does the quick brown fox jump over the lazy dog?\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -421,7 +422,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Eren, how are you?\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -438,7 +439,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Encouraged, he started with a minute a day.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -448,7 +449,7 @@
    "outputs": [],
    "source": [
     "sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -458,7 +459,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . \"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -468,7 +469,7 @@
    "outputs": [],
    "source": [
     "sentence = \"If he decided to watch TV he really watched it.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -480,7 +481,7 @@
    "outputs": [],
    "source": [
     "sentence = \"Often we try to bring about change through sheer effort and we put all of our energy into a new initiative .\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -491,7 +492,7 @@
    "source": [
     "# for twb dataset\n",
     "sentence = \"In our preparation for Easter, God in his providence offers us each year the season of Lent as a sacramental sign of our conversion.\"\n",
-    "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
+    "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
    ]
   },
   {
@@ -10,3 +10,4 @@ flask
 scipy==0.19.0
 tqdm
 git+git://github.com/bootphon/phonemizer@master
+soundfile
train.py (50 lines changed)
@@ -1,5 +1,6 @@
 import argparse
 import importlib
+import json
 import os
 import shutil
 import sys
@@ -25,6 +26,8 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  save_checkpoint, sequence_mask, weight_decay,
                                  set_init_dict, copy_config_file, setup_model)
 from utils.logger import Logger
+from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
+    copy_speaker_mapping
 from utils.synthesis import synthesis
 from utils.text.symbols import phonemes, symbols
 from utils.visual import plot_alignment, plot_spectrogram
@@ -75,6 +78,7 @@ def setup_loader(is_val=False, verbose=False):
 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
           ap, epoch):
     data_loader = setup_loader(is_val=False, verbose=(epoch==0))
+    speaker_mapping = load_speaker_mapping(OUT_PATH)
     model.train()
     epoch_time = 0
     avg_postnet_loss = 0
@@ -89,13 +93,21 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         # setup input data
         text_input = data[0]
         text_lengths = data[1]
-        linear_input = data[2] if c.model in ["Tacotron", "TacotronGST"] else None
-        mel_input = data[3]
-        mel_lengths = data[4]
-        stop_targets = data[5]
+        speaker_names = data[2]
+        linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
+        mel_input = data[4]
+        mel_lengths = data[5]
+        stop_targets = data[6]
         avg_text_length = torch.mean(text_lengths.float())
         avg_spec_length = torch.mean(mel_lengths.float())

+        speaker_ids = []
+        for speaker_name in speaker_names:
+            if speaker_name not in speaker_mapping:
+                speaker_mapping[speaker_name] = len(speaker_mapping)
+            speaker_ids.append(speaker_mapping[speaker_name])
+        speaker_ids = torch.LongTensor(speaker_ids)
+
         # set stop targets view, we predict a single stop token per r frames prediction
         stop_targets = stop_targets.view(text_input.shape[0],
                                          stop_targets.size(1) // c.r, -1)
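Note: the loop above assigns speaker ids on first sight, so every new speaker gets the next free integer. A toy illustration with made-up names:

speaker_mapping = {}
for speaker_name in ["ljspeech", "nancy", "ljspeech", "tweb"]:
    if speaker_name not in speaker_mapping:
        speaker_mapping[speaker_name] = len(speaker_mapping)
print(speaker_mapping)  # {'ljspeech': 0, 'nancy': 1, 'tweb': 2}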
@@ -118,10 +130,11 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             mel_lengths = mel_lengths.cuda(non_blocking=True)
             linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
             stop_targets = stop_targets.cuda(non_blocking=True)
+            speaker_ids = speaker_ids.cuda(non_blocking=True)

         # forward pass model
         decoder_output, postnet_output, alignments, stop_tokens = model(
-            text_input, text_lengths, mel_input)
+            text_input, speaker_ids, text_lengths, mel_input)

         # loss computation
         stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
@@ -244,11 +257,15 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
         if c.tb_model_param_stats:
             tb_logger.tb_model_weights(model, current_step)
+
+    # save speaker mapping
+    save_speaker_mapping(OUT_PATH, speaker_mapping)
     return avg_postnet_loss, current_step


 def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
     data_loader = setup_loader(is_val=True)
+    speaker_mapping = load_speaker_mapping(OUT_PATH)
     model.eval()
     epoch_time = 0
     avg_postnet_loss = 0
@@ -273,10 +290,15 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
             # setup input data
             text_input = data[0]
             text_lengths = data[1]
-            linear_input = data[2] if c.model in ["Tacotron", "TacotronGST"] else None
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_targets = data[5]
+            speaker_names = data[2]
+            linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
+            mel_input = data[4]
+            mel_lengths = data[5]
+            stop_targets = data[6]
+
+            speaker_ids = [speaker_mapping[speaker_name]
+                           for speaker_name in speaker_names]
+            speaker_ids = torch.LongTensor(speaker_ids)

             # set stop targets view, we predict a single stop token per r frames prediction
             stop_targets = stop_targets.view(text_input.shape[0],
@@ -291,10 +313,12 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
                 mel_lengths = mel_lengths.cuda()
                 linear_input = linear_input.cuda() if c.model in ["Tacotron", "TacotronGST"] else None
                 stop_targets = stop_targets.cuda()
+                speaker_ids = speaker_ids.cuda()

             # forward pass
             decoder_output, postnet_output, alignments, stop_tokens =\
-                model.forward(text_input, text_lengths, mel_input)
+                model.forward(text_input, speaker_ids,
+                              text_lengths, mel_input)

             # loss computation
             stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
@@ -372,10 +396,11 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
         test_audios = {}
         test_figures = {}
         print(" | > Synthesizing test sentences")
+        speaker_id = 0
         for idx, test_sentence in enumerate(test_sentences):
             try:
                 wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
-                    model, test_sentence, c, use_cuda, ap)
+                    model, test_sentence, speaker_id, c, use_cuda, ap)
                 file_path = os.path.join(AUDIO_PATH, str(current_step))
                 os.makedirs(file_path, exist_ok=True)
                 file_path = os.path.join(file_path,
@@ -437,6 +462,9 @@ def main(args):
             " > Model restored from step %d" % checkpoint['step'], flush=True)
         start_epoch = checkpoint['epoch']
         args.restore_step = checkpoint['step']
+        # copying speakers.json
+        prev_out_path = os.path.dirname(args.restore_path)
+        copy_speaker_mapping(prev_out_path, OUT_PATH)
     else:
         args.restore_step = 0
@@ -236,7 +236,10 @@ class AudioProcessor(object):
         if self.do_trim_silence:
             x = self.trim_silence(x)
         # sr, x = io.wavfile.read(filename)
-        assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
+        assert self.sample_rate == sr, "Expected sampling rate {} but file " \
+                                       "{} has {}.".format(self.sample_rate,
+                                                           filename,
+                                                           sr)
         return x

     def encode_16bits(self, x):
@@ -33,9 +33,13 @@ def load_config(config_path):


 def get_git_branch():
-    out = subprocess.check_output(["git", "branch"]).decode("utf8")
-    current = next(line for line in out.split("\n") if line.startswith("*"))
-    return current.replace("* ", "")
+    try:
+        out = subprocess.check_output(["git", "branch"]).decode("utf8")
+        current = next(line for line in out.split("\n") if line.startswith("*"))
+        current.replace("* ", "")
+    except subprocess.CalledProcessError:
+        current = "inside_docker"
+    return current


 def get_commit_hash():
@@ -46,8 +50,12 @@ def get_commit_hash():
     # except:
     #     raise RuntimeError(
     #         " !! Commit before training to get the commit hash.")
-    commit = subprocess.check_output(['git', 'rev-parse', '--short',
-                                      'HEAD']).decode().strip()
+    try:
+        commit = subprocess.check_output(['git', 'rev-parse', '--short',
+                                          'HEAD']).decode().strip()
+    # Not copying .git folder into docker container
+    except subprocess.CalledProcessError:
+        commit = "0000000"
     print(' > Git Hash: {}'.format(commit))
     return commit
@@ -250,6 +258,7 @@ def setup_model(num_chars, c):
     if c.model.lower() in ["tacotron", "tacotrongst"]:
         model = MyModel(
             num_chars=num_chars,
+            num_speakers=c.num_speakers,
             r=c.r,
             linear_dim=1025,
             mel_dim=80,
@@ -266,6 +275,7 @@ def setup_model(num_chars, c):
     elif c.model.lower() == "tacotron2":
         model = MyModel(
             num_chars=num_chars,
+            num_speakers=c.num_speakers,
             r=c.r,
             attn_win=c.windowing,
             attn_norm=c.attention_norm,
@@ -0,0 +1,30 @@
+import os
+import json
+
+
+def make_speakers_json_path(out_path):
+    """Returns conventional speakers.json location."""
+    return os.path.join(out_path, "speakers.json")
+
+
+def load_speaker_mapping(out_path):
+    """Loads speaker mapping if already present."""
+    try:
+        with open(make_speakers_json_path(out_path)) as f:
+            return json.load(f)
+    except FileNotFoundError:
+        return {}
+
+
+def save_speaker_mapping(out_path, speaker_mapping):
+    """Saves speaker mapping if not yet present."""
+    speakers_json_path = make_speakers_json_path(out_path)
+    with open(speakers_json_path, "w") as f:
+        json.dump(speaker_mapping, f, indent=4)
+
+
+def copy_speaker_mapping(out_path_a, out_path_b):
+    """Copies a speaker mapping when restoring a model from a previous path."""
+    speaker_mapping = load_speaker_mapping(out_path_a)
+    if speaker_mapping is not None:
+        save_speaker_mapping(out_path_b, speaker_mapping)
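Note: a short usage sketch for the new utils.speakers helpers; the output directory below is an assumption, not repo code:

import os
from utils.speakers import load_speaker_mapping, save_speaker_mapping

OUT_PATH = "/tmp/tts_run"                  # hypothetical output folder
os.makedirs(OUT_PATH, exist_ok=True)

mapping = load_speaker_mapping(OUT_PATH)   # {} on a fresh run
for name in ["ljspeech", "nancy"]:
    if name not in mapping:
        mapping[name] = len(mapping)       # same rule train.py applies
save_speaker_mapping(OUT_PATH, mapping)    # writes OUT_PATH/speakers.json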
@@ -70,6 +70,7 @@ def inv_spectrogram(postnet_output, ap, CONFIG):

 def synthesis(model,
               text,
+              speaker_id,
               CONFIG,
               use_cuda,
               ap,
@@ -82,6 +83,7 @@ def synthesis(model,
     Args:
         model (TTS.models): model to synthesize.
         text (str): target text
+        speaker_id (int): id of speaker
         CONFIG (dict): config dictionary to be loaded from config.json.
         use_cuda (bool): enable cuda.
         ap (TTS.utils.audio.AudioProcessor): audio processor to process
@@ -98,6 +100,9 @@ def synthesis(model,
         style_mel = compute_style_mel(style_wav, ap, use_cuda)
     # preprocess the given text
     inputs = text_to_seqvec(text, CONFIG, use_cuda)
+    speaker_id = speaker_id_var = torch.from_numpy(speaker_id).unsqueeze(0)
+    if use_cuda:
+        speaker_id.cuda()
     # synthesize voice
     decoder_output, postnet_output, alignments, stop_tokens = run_model(
         model, inputs, CONFIG, truncated, style_mel)
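Note: torch.from_numpy expects an ndarray, while the notebook and train.py pass speaker_id as a plain int (speaker_id = 0), so a caller would need to wrap it first. A hedged sketch of that conversion, not repo code:

import numpy as np
import torch

speaker_id = np.asarray(0)                              # wrap the plain int
speaker_id = torch.from_numpy(speaker_id).unsqueeze(0)  # shape (1,)
if torch.cuda.is_available():
    speaker_id = speaker_id.cuda()  # .cuda() is not in-place; reassign it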