diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..4032ec6b --- /dev/null +++ b/.dockerignore @@ -0,0 +1 @@ +.git/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 663cfdb5..43f2e9e9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,23 +1,17 @@ -FROM nvidia/cuda:9.0-base-ubuntu16.04 as base +FROM pytorch/pytorch:1.0.1-cuda10.0-cudnn7-runtime WORKDIR /srv/app RUN apt-get update && \ - apt-get install -y git software-properties-common wget vim build-essential libsndfile1 && \ - add-apt-repository ppa:deadsnakes/ppa && \ - apt-get update && \ - apt-get install -y python3.6 python3.6-dev python3.6-tk && \ - # Install pip manually - wget https://bootstrap.pypa.io/get-pip.py && \ - python3.6 get-pip.py && \ - rm get-pip.py && \ - # Used by the server in server/synthesizer.py - pip install soundfile + apt-get install -y libsndfile1 espeak && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -ADD . /srv/app +# Copy Source later to enable dependency caching +COPY requirements.txt /srv/app/ +RUN pip install -r requirements.txt -# Setup for development -RUN python3.6 setup.py develop +COPY . /srv/app # http://bugs.python.org/issue19846 # > At the moment, setting "LANG=C" on a Linux system *fundamentally breaks Python 3*, and that's not OK. diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py index 3194e81f..67553ec2 100644 --- a/datasets/TTSDataset.py +++ b/datasets/TTSDataset.py @@ -37,6 +37,8 @@ class MyDataset(Dataset): ap (TTS.utils.AudioProcessor): audio processor object. preprocessor (dataset.preprocess.Class): preprocessor for the dataset. Create your own if you need to run a new dataset. + speaker_id_cache_path (str): path where the speaker name to id + mapping is stored batch_group_size (int): (0) range of batch randomization after sorting sequences by length. 
min_seq_len (int): (0) minimum sequence length to be processed @@ -105,7 +107,7 @@ class MyDataset(Dataset): return text def load_data(self, idx): - text, wav_file = self.items[idx] + text, wav_file, speaker_name = self.items[idx] wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) if self.use_phonemes: @@ -120,7 +122,8 @@ class MyDataset(Dataset): sample = { 'text': text, 'wav': wav, - 'item_idx': self.items[idx][1] + 'item_idx': self.items[idx][1], + 'speaker_name': speaker_name } return sample @@ -182,6 +185,8 @@ class MyDataset(Dataset): batch[idx]['item_idx'] for idx in ids_sorted_decreasing ] text = [batch[idx]['text'] for idx in ids_sorted_decreasing] + speaker_name = [batch[idx]['speaker_name'] + for idx in ids_sorted_decreasing] mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] linear = [self.ap.spectrogram(w).astype('float32') for w in wav] @@ -219,7 +224,8 @@ class MyDataset(Dataset): mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) - return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs + return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \ + stop_targets, item_idxs raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ found {}".format(type(batch[0])))) diff --git a/datasets/preprocess.py b/datasets/preprocess.py index d7605fa0..3462093e 100644 --- a/datasets/preprocess.py +++ b/datasets/preprocess.py @@ -1,5 +1,6 @@ import os from glob import glob +import re def tweb(root_path, meta_file): @@ -8,12 +9,13 @@ def tweb(root_path, meta_file): """ txt_file = os.path.join(root_path, meta_file) items = [] + speaker_name = "tweb" with open(txt_file, 'r') as ttf: for line in ttf: cols = line.split('\t') wav_file = os.path.join(root_path, cols[0] + '.wav') text = cols[1] - items.append([text, wav_file]) + items.append([text, wav_file, speaker_name]) return items @@ -34,6 +36,7 @@ def mozilla_old(root_path, meta_file): """Normalizes Mozilla meta data files to TTS format""" txt_file = os.path.join(root_path, meta_file) items = [] + speaker_name = "mozilla_old" with open(txt_file, 'r') as ttf: for line in ttf: cols = line.split('|') @@ -41,7 +44,7 @@ wav_folder = "batch{}".format(batch_no) wav_file = os.path.join(root_path, wav_folder, "wavs_no_processing", cols[1].strip()) text = cols[0].strip() - items.append([text, wav_file]) + items.append([text, wav_file, speaker_name]) return items @@ -49,27 +52,31 @@ def mozilla(root_path, meta_file): """Normalizes Mozilla meta data files to TTS format""" txt_file = os.path.join(root_path, meta_file) items = [] + speaker_name = "mozilla" with open(txt_file, 'r') as ttf: for line in ttf: cols = line.split('|') wav_file = cols[1].strip() text = cols[0].strip() wav_file = os.path.join(root_path, "wavs", wav_file) - items.append([text, wav_file]) + items.append([text, wav_file, speaker_name]) return items def mailabs(root_path, meta_files): """Normalizes M-AI-Labs meta data files to TTS format""" + speaker_regex = re.compile("by_book/(male|female|mix)/(?P<speaker_name>[^/]+)/") if meta_files is None: csv_files = glob(root_path+"/**/metadata.csv", recursive=True) folders = [os.path.dirname(f) for f in csv_files] else: csv_files = meta_files - folders = [f.strip().split("by_book")[1][1:] for f in csv_file] + folders = [f.strip().split("by_book")[1][1:] for f in csv_files] # meta_files = [f.strip() for f in meta_files.split(",")] items = [] for idx, csv_file in enumerate(csv_files): + # determine speaker based
on folder structure... + speaker_name = speaker_regex.search(csv_file).group("speaker_name") print(" | > {}".format(csv_file)) folder = folders[idx] txt_file = os.path.join(root_path, csv_file) @@ -82,7 +89,7 @@ def mailabs(root_path, meta_files): wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), 'wavs', cols[0] + '.wav') if os.path.isfile(wav_file): text = cols[1].strip() - items.append([text, wav_file]) + items.append([text, wav_file, speaker_name]) else: raise RuntimeError("> File %s is not exist!"%(wav_file)) return items @@ -92,12 +99,13 @@ def ljspeech(root_path, meta_file): """Normalizes the Nancy meta data file to TTS format""" txt_file = os.path.join(root_path, meta_file) items = [] + speaker_name = "ljspeech" with open(txt_file, 'r') as ttf: for line in ttf: cols = line.split('|') wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav') text = cols[1] - items.append([text, wav_file]) + items.append([text, wav_file, speaker_name]) return items @@ -105,12 +113,13 @@ def nancy(root_path, meta_file): """Normalizes the Nancy meta data file to TTS format""" txt_file = os.path.join(root_path, meta_file) items = [] + speaker_name = "nancy" with open(txt_file, 'r') as ttf: for line in ttf: id = line.split()[1] text = line[line.find('"') + 1:line.rfind('"') - 1] wav_file = os.path.join(root_path, "wavn", id + ".wav") - items.append([text, wav_file]) + items.append([text, wav_file, speaker_name]) return items @@ -124,6 +133,7 @@ def common_voice(root_path, meta_file): continue cols = line.split("\t") text = cols[2] + speaker_name = cols[0] wav_file = os.path.join(root_path, "clips", cols[1] + ".wav") - items.append([text, wav_file]) + items.append([text, wav_file, speaker_name]) return items diff --git a/models/tacotron.py b/models/tacotron.py index 5d2af992..bac73ff7 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -9,6 +9,7 @@ from utils.generic_utils import sequence_mask class Tacotron(nn.Module): def __init__(self, num_chars, + num_speakers, r=5, linear_dim=1025, mel_dim=80, @@ -28,6 +29,9 @@ class Tacotron(nn.Module): self.linear_dim = linear_dim self.embedding = nn.Embedding(num_chars, 256) self.embedding.weight.data.normal_(0, 0.3) + self.speaker_embedding = nn.Embedding(num_speakers, + 256) + self.speaker_embedding.weight.data.normal_(0, 0.3) self.encoder = Encoder(256) self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win, attn_norm, prenet_type, prenet_dropout, @@ -38,11 +42,18 @@ class Tacotron(nn.Module): nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim), nn.Sigmoid()) - def forward(self, characters, text_lengths, mel_specs): + def forward(self, characters, speaker_ids, text_lengths, mel_specs): B = characters.size(0) mask = sequence_mask(text_lengths).to(characters.device) inputs = self.embedding(characters) encoder_outputs = self.encoder(inputs) + speaker_embeddings = self.speaker_embedding(speaker_ids) + + speaker_embeddings.unsqueeze_(1) + speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), + encoder_outputs.size(1), + -1) + encoder_outputs += speaker_embeddings mel_outputs, alignments, stop_tokens = self.decoder( encoder_outputs, mel_specs, mask) mel_outputs = mel_outputs.view(B, -1, self.mel_dim) @@ -50,10 +61,17 @@ class Tacotron(nn.Module): linear_outputs = self.last_linear(linear_outputs) return mel_outputs, linear_outputs, alignments, stop_tokens - def inference(self, characters): + def inference(self, characters, speaker_ids): B = characters.size(0) inputs = self.embedding(characters) 
encoder_outputs = self.encoder(inputs) + speaker_embeddings = self.speaker_embedding(speaker_ids) + + speaker_embeddings.unsqueeze_(1) + speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), + encoder_outputs.size(1), + -1) + encoder_outputs += speaker_embeddings mel_outputs, alignments, stop_tokens = self.decoder.inference( encoder_outputs) mel_outputs = mel_outputs.view(B, -1, self.mel_dim) diff --git a/models/tacotron2.py b/models/tacotron2.py index c306a174..ba565e6f 100644 --- a/models/tacotron2.py +++ b/models/tacotron2.py @@ -11,6 +11,7 @@ from utils.generic_utils import sequence_mask class Tacotron2(nn.Module): def __init__(self, num_chars, + num_speakers, r, attn_win=False, attn_norm="softmax", @@ -28,6 +29,8 @@ class Tacotron2(nn.Module): std = sqrt(2.0 / (num_chars + 512)) val = sqrt(3.0) * std # uniform bounds for std self.embedding.weight.data.uniform_(-val, val) + self.speaker_embedding = nn.Embedding(num_speakers, 512) + self.speaker_embedding.weight.data.normal_(0, 0.3) self.encoder = Encoder(512) self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, prenet_dropout, @@ -40,11 +43,19 @@ class Tacotron2(nn.Module): mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) return mel_outputs, mel_outputs_postnet, alignments - def forward(self, text, text_lengths, mel_specs=None): + def forward(self, text, speaker_ids, text_lengths, mel_specs=None): # compute mask for padding mask = sequence_mask(text_lengths).to(text.device) embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder(embedded_inputs, text_lengths) + speaker_embeddings = self.speaker_embedding(speaker_ids) + + speaker_embeddings.unsqueeze_(1) + speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), + encoder_outputs.size(1), + -1) + + encoder_outputs += speaker_embeddings mel_outputs, stop_tokens, alignments = self.decoder( encoder_outputs, mel_specs, mask) mel_outputs_postnet = self.postnet(mel_outputs) @@ -53,9 +64,16 @@ class Tacotron2(nn.Module): mel_outputs, mel_outputs_postnet, alignments) return mel_outputs, mel_outputs_postnet, alignments, stop_tokens - def inference(self, text): + def inference(self, text, speaker_ids): embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) + speaker_embeddings = self.speaker_embedding(speaker_ids) + + speaker_embeddings.unsqueeze_(1) + speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), + encoder_outputs.size(1), + -1) + encoder_outputs += speaker_embeddings mel_outputs, stop_tokens, alignments = self.decoder.inference( encoder_outputs) mel_outputs_postnet = self.postnet(mel_outputs) @@ -64,12 +82,19 @@ class Tacotron2(nn.Module): mel_outputs, mel_outputs_postnet, alignments) return mel_outputs, mel_outputs_postnet, alignments, stop_tokens - def inference_truncated(self, text): + def inference_truncated(self, text, speaker_ids): """ Preserve model states for continuous inference """ embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference_truncated(embedded_inputs) + speaker_embeddings = self.speaker_embedding(speaker_ids) + + speaker_embeddings.unsqueeze_(1) + speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), + encoder_outputs.size(1), + -1) + encoder_outputs += speaker_embeddings mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated( encoder_outputs) mel_outputs_postnet = self.postnet(mel_outputs) diff --git 
a/models/tacotrongst.py b/models/tacotrongst.py index 1a77cd53..8a75a5fa 100644 --- a/models/tacotrongst.py +++ b/models/tacotrongst.py @@ -10,6 +10,7 @@ from utils.generic_utils import sequence_mask class TacotronGST(nn.Module): def __init__(self, num_chars, + num_speakers, r=5, linear_dim=1025, mel_dim=80, @@ -29,6 +30,8 @@ class TacotronGST(nn.Module): self.linear_dim = linear_dim self.embedding = nn.Embedding(num_chars, 256) self.embedding.weight.data.normal_(0, 0.3) + self.speaker_embedding = nn.Embedding(num_speakers, 256) + self.speaker_embedding.weight.data.normal_(0, 0.3) self.encoder = Encoder(256) self.gst = GST(num_mel=80, num_heads=4, num_style_tokens=10, embedding_dim=256) self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win, @@ -40,14 +43,22 @@ class TacotronGST(nn.Module): nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim), nn.Sigmoid()) - def forward(self, characters, text_lengths, mel_specs): + def forward(self, characters, speaker_ids, text_lengths, mel_specs): B = characters.size(0) mask = sequence_mask(text_lengths).to(characters.device) inputs = self.embedding(characters) encoder_outputs = self.encoder(inputs) + + speaker_embeddings = self.speaker_embedding(speaker_ids) + + speaker_embeddings.unsqueeze_(1) + speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), + encoder_outputs.size(1), + -1) + gst_outputs = self.gst(mel_specs) gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1) - encoder_outputs = encoder_outputs + gst_outputs + encoder_outputs = encoder_outputs + gst_outputs + speaker_embeddings mel_outputs, alignments, stop_tokens = self.decoder( encoder_outputs, mel_specs, mask) mel_outputs = mel_outputs.view(B, -1, self.mel_dim) @@ -55,14 +66,21 @@ class TacotronGST(nn.Module): linear_outputs = self.last_linear(linear_outputs) return mel_outputs, linear_outputs, alignments, stop_tokens - def inference(self, characters, style_mel=None): + def inference(self, characters, speaker_ids, style_mel=None): B = characters.size(0) inputs = self.embedding(characters) encoder_outputs = self.encoder(inputs) + speaker_embeddings = self.speaker_embedding(speaker_ids) + + speaker_embeddings.unsqueeze_(1) + speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), + encoder_outputs.size(1), + -1) if style_mel is not None: gst_outputs = self.gst(style_mel) gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1) encoder_outputs = encoder_outputs + gst_outputs + encoder_outputs += speaker_embeddings mel_outputs, alignments, stop_tokens = self.decoder.inference( encoder_outputs) mel_outputs = mel_outputs.view(B, -1, self.mel_dim) diff --git a/notebooks/Benchmark.ipynb b/notebooks/Benchmark.ipynb index 349575eb..cde8edf7 100644 --- a/notebooks/Benchmark.ipynb +++ b/notebooks/Benchmark.ipynb @@ -79,9 +79,9 @@ "metadata": {}, "outputs": [], "source": [ - "def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n", + "def tts(model, text, speaker_id, CONFIG, use_cuda, ap, use_gl, figures=True):\n", " t_1 = time.time()\n", - " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, False, CONFIG.enable_eos_bos_chars)\n", + " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, speaker_id, CONFIG, use_cuda, ap, False, CONFIG.enable_eos_bos_chars)\n", " if CONFIG.model == \"Tacotron\" and not use_gl:\n", " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", " if not use_gl:\n", @@ -208,8 +208,9 @@ "source": [ 
"model.eval()\n", "model.decoder.max_decoder_steps = 2000\n", + "speaker_id = 0\n", "sentence = \"Bill got in the habit of asking himself “Is that thought true?” And if he wasn’t absolutely certain it was, he just let it go.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -221,7 +222,7 @@ "outputs": [], "source": [ "sentence = \"Be a voice, not an echo.\" # 'echo' is not in training set. \n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -231,7 +232,7 @@ "outputs": [], "source": [ "sentence = \"The human voice is the most perfect instrument of all.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -241,7 +242,7 @@ "outputs": [], "source": [ "sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -253,7 +254,7 @@ "outputs": [], "source": [ "sentence = \"This cake is great. It's so delicious and moist.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -270,7 +271,7 @@ "outputs": [], "source": [ "sentence = \"Generative adversarial network or variational auto-encoder.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -280,7 +281,7 @@ "outputs": [], "source": [ "sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -290,7 +291,7 @@ "outputs": [], "source": [ "sentence = \"Here’s a way to measure the acute emotional intelligence that has never gone out of style.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -300,7 +301,7 @@ "outputs": [], "source": [ "sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -310,7 +311,7 @@ "outputs": [], "source": [ "sentence = \"The buses aren't the problem, they actually provide a solution.\"\n", - "align, spec, stop_tokens, wav = 
tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -327,7 +328,7 @@ "outputs": [], "source": [ "sentence = \"Generative adversarial network or variational auto-encoder.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -337,7 +338,7 @@ "outputs": [], "source": [ "sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -347,7 +348,7 @@ "outputs": [], "source": [ "sentence = \" He has read the whole thing.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -357,7 +358,7 @@ "outputs": [], "source": [ "sentence = \"He reads books.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -369,7 +370,7 @@ "outputs": [], "source": [ "sentence = \"Thisss isrealy awhsome.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -381,7 +382,7 @@ "outputs": [], "source": [ "sentence = \"This is your internet browser, Firefox.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -391,7 +392,7 @@ "outputs": [], "source": [ "sentence = \"This is your internet browser Firefox.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -401,7 +402,7 @@ "outputs": [], "source": [ "sentence = \"The quick brown fox jumps over the lazy dog.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -411,7 +412,7 @@ "outputs": [], "source": [ "sentence = \"Does the quick brown fox jump over the lazy dog?\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -421,7 +422,7 @@ "outputs": [], "source": [ "sentence = \"Eren, how are you?\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, 
use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -438,7 +439,7 @@ "outputs": [], "source": [ "sentence = \"Encouraged, he started with a minute a day.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -448,7 +449,7 @@ "outputs": [], "source": [ "sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -458,7 +459,7 @@ "outputs": [], "source": [ "sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . \"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -468,7 +469,7 @@ "outputs": [], "source": [ "sentence = \"If he decided to watch TV he really watched it.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -480,7 +481,7 @@ "outputs": [], "source": [ "sentence = \"Often we try to bring about change through sheer effort and we put all of our energy into a new initiative .\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { @@ -491,7 +492,7 @@ "source": [ "# for twb dataset\n", "sentence = \"In our preparation for Easter, God in his providence offers us each year the season of Lent as a sacramental sign of our conversion.\"\n", - "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + "align, spec, stop_tokens, wav = tts(model, sentence, speaker_id, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" ] }, { diff --git a/requirements.txt b/requirements.txt index 6a7a446f..c9f074f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ flask scipy==0.19.0 tqdm git+git://github.com/bootphon/phonemizer@master +soundfile \ No newline at end of file diff --git a/train.py b/train.py index 77857946..3b6ff866 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,6 @@ import argparse import importlib +import json import os import shutil import sys @@ -25,6 +26,8 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters, save_checkpoint, sequence_mask, weight_decay, set_init_dict, copy_config_file, setup_model) from utils.logger import Logger +from utils.speakers import load_speaker_mapping, save_speaker_mapping, \ + copy_speaker_mapping from utils.synthesis import synthesis from utils.text.symbols import phonemes, symbols from utils.visual import plot_alignment, plot_spectrogram @@ -75,6 +78,7 @@ def setup_loader(is_val=False, verbose=False): def train(model, criterion, criterion_st, optimizer, 
optimizer_st, scheduler, ap, epoch): data_loader = setup_loader(is_val=False, verbose=(epoch==0)) + speaker_mapping = load_speaker_mapping(OUT_PATH) model.train() epoch_time = 0 avg_postnet_loss = 0 @@ -89,13 +93,21 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, # setup input data text_input = data[0] text_lengths = data[1] - linear_input = data[2] if c.model in ["Tacotron", "TacotronGST"] else None - mel_input = data[3] - mel_lengths = data[4] - stop_targets = data[5] + speaker_names = data[2] + linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None + mel_input = data[4] + mel_lengths = data[5] + stop_targets = data[6] avg_text_length = torch.mean(text_lengths.float()) avg_spec_length = torch.mean(mel_lengths.float()) + speaker_ids = [] + for speaker_name in speaker_names: + if speaker_name not in speaker_mapping: + speaker_mapping[speaker_name] = len(speaker_mapping) + speaker_ids.append(speaker_mapping[speaker_name]) + speaker_ids = torch.LongTensor(speaker_ids) + # set stop targets view, we predict a single stop token per r frames prediction stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) @@ -118,10 +130,11 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, mel_lengths = mel_lengths.cuda(non_blocking=True) linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None stop_targets = stop_targets.cuda(non_blocking=True) + speaker_ids = speaker_ids.cuda(non_blocking=True) # forward pass model decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input) + text_input, speaker_ids, text_lengths, mel_input) # loss computation stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1) @@ -178,7 +191,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, if args.rank == 0: avg_postnet_loss += float(postnet_loss.item()) avg_decoder_loss += float(decoder_loss.item()) - avg_stop_loss += stop_loss if type(stop_loss) is float else float(stop_loss.item()) + avg_stop_loss += stop_loss if type(stop_loss) is float else float(stop_loss.item()) avg_step_time += step_time # Plot Training Iter Stats @@ -243,12 +256,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, "epoch_time": epoch_time} tb_logger.tb_train_epoch_stats(current_step, epoch_stats) if c.tb_model_param_stats: - tb_logger.tb_model_weights(model, current_step) + tb_logger.tb_model_weights(model, current_step) + + # save speaker mapping + save_speaker_mapping(OUT_PATH, speaker_mapping) return avg_postnet_loss, current_step def evaluate(model, criterion, criterion_st, ap, current_step, epoch): data_loader = setup_loader(is_val=True) + speaker_mapping = load_speaker_mapping(OUT_PATH) model.eval() epoch_time = 0 avg_postnet_loss = 0 @@ -273,10 +290,15 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch): # setup input data text_input = data[0] text_lengths = data[1] - linear_input = data[2] if c.model in ["Tacotron", "TacotronGST"] else None - mel_input = data[3] - mel_lengths = data[4] - stop_targets = data[5] + speaker_names = data[2] + linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None + mel_input = data[4] + mel_lengths = data[5] + stop_targets = data[6] + + speaker_ids = [speaker_mapping[speaker_name] + for speaker_name in speaker_names] + speaker_ids = torch.LongTensor(speaker_ids) # set stop targets view, we 
predict a single stop token per r frames prediction stop_targets = stop_targets.view(text_input.shape[0], @@ -291,10 +313,12 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch): mel_lengths = mel_lengths.cuda() linear_input = linear_input.cuda() if c.model in ["Tacotron", "TacotronGST"] else None stop_targets = stop_targets.cuda() + speaker_ids = speaker_ids.cuda() # forward pass decoder_output, postnet_output, alignments, stop_tokens =\ - model.forward(text_input, text_lengths, mel_input) + model.forward(text_input, speaker_ids, + text_lengths, mel_input) # loss computation stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1) @@ -372,10 +396,11 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch): test_audios = {} test_figures = {} print(" | > Synthesizing test sentences") + speaker_id = 0 for idx, test_sentence in enumerate(test_sentences): try: wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis( - model, test_sentence, c, use_cuda, ap) + model, test_sentence, speaker_id, c, use_cuda, ap) file_path = os.path.join(AUDIO_PATH, str(current_step)) os.makedirs(file_path, exist_ok=True) file_path = os.path.join(file_path, @@ -437,6 +462,9 @@ def main(args): " > Model restored from step %d" % checkpoint['step'], flush=True) start_epoch = checkpoint['epoch'] args.restore_step = checkpoint['step'] + # copying speakers.json + prev_out_path = os.path.dirname(args.restore_path) + copy_speaker_mapping(prev_out_path, OUT_PATH) else: args.restore_step = 0 diff --git a/utils/audio.py b/utils/audio.py index e14f2b7e..fb3edad3 100644 --- a/utils/audio.py +++ b/utils/audio.py @@ -236,7 +236,10 @@ class AudioProcessor(object): if self.do_trim_silence: x = self.trim_silence(x) # sr, x = io.wavfile.read(filename) - assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr) + assert self.sample_rate == sr, "Expected sampling rate {} but file " \ + "{} has {}.".format(self.sample_rate, + filename, + sr) return x def encode_16bits(self, x): diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 4fecd2a0..eb574f8d 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -33,9 +33,13 @@ def load_config(config_path): def get_git_branch(): - out = subprocess.check_output(["git", "branch"]).decode("utf8") - current = next(line for line in out.split("\n") if line.startswith("*")) - return current.replace("* ", "") + try: + out = subprocess.check_output(["git", "branch"]).decode("utf8") + current = next(line for line in out.split("\n") if line.startswith("*")) + current.replace("* ", "") + except subprocess.CalledProcessError: + current = "inside_docker" + return current def get_commit_hash(): @@ -46,8 +50,12 @@ def get_commit_hash(): # except: # raise RuntimeError( # " !! 
Commit before training to get the commit hash.") - commit = subprocess.check_output(['git', 'rev-parse', '--short', - 'HEAD']).decode().strip() + try: + commit = subprocess.check_output(['git', 'rev-parse', '--short', + 'HEAD']).decode().strip() + # Not copying .git folder into docker container + except subprocess.CalledProcessError: + commit = "0000000" print(' > Git Hash: {}'.format(commit)) return commit @@ -250,6 +258,7 @@ def setup_model(num_chars, c): if c.model.lower() in ["tacotron", "tacotrongst"]: model = MyModel( num_chars=num_chars, + num_speakers=c.num_speakers, r=c.r, linear_dim=1025, mel_dim=80, @@ -266,6 +275,7 @@ def setup_model(num_chars, c): elif c.model.lower() == "tacotron2": model = MyModel( num_chars=num_chars, + num_speakers=c.num_speakers, r=c.r, attn_win=c.windowing, attn_norm=c.attention_norm, @@ -276,4 +286,4 @@ def setup_model(num_chars, c): forward_attn_mask=c.forward_attn_mask, location_attn=c.location_attn, separate_stopnet=c.separate_stopnet) - return model \ No newline at end of file + return model diff --git a/utils/speakers.py b/utils/speakers.py new file mode 100644 index 00000000..6e3460d2 --- /dev/null +++ b/utils/speakers.py @@ -0,0 +1,30 @@ +import os +import json + + +def make_speakers_json_path(out_path): + """Returns conventional speakers.json location.""" + return os.path.join(out_path, "speakers.json") + + +def load_speaker_mapping(out_path): + """Loads speaker mapping if already present.""" + try: + with open(make_speakers_json_path(out_path)) as f: + return json.load(f) + except FileNotFoundError: + return {} + + +def save_speaker_mapping(out_path, speaker_mapping): + """Saves speaker mapping if not yet present.""" + speakers_json_path = make_speakers_json_path(out_path) + with open(speakers_json_path, "w") as f: + json.dump(speaker_mapping, f, indent=4) + + +def copy_speaker_mapping(out_path_a, out_path_b): + """Copies a speaker mapping when restoring a model from a previous path.""" + speaker_mapping = load_speaker_mapping(out_path_a) + if speaker_mapping is not None: + save_speaker_mapping(out_path_b, speaker_mapping) diff --git a/utils/synthesis.py b/utils/synthesis.py index 6ae056b5..deb07a5c 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -70,6 +70,7 @@ def inv_spectrogram(postnet_output, ap, CONFIG): def synthesis(model, text, + speaker_id, CONFIG, use_cuda, ap, @@ -82,6 +83,7 @@ def synthesis(model, Args: model (TTS.models): model to synthesize. text (str): target text + speaker_id (int): id of speaker CONFIG (dict): config dictionary to be loaded from config.json. use_cuda (bool): enable cuda. ap (TTS.utils.audio.AudioProcessor): audio processor to process @@ -98,6 +100,9 @@ def synthesis(model, style_mel = compute_style_mel(style_wav, ap, use_cuda) # preprocess the given text inputs = text_to_seqvec(text, CONFIG, use_cuda) + speaker_id = speaker_id_var = torch.from_numpy(speaker_id).unsqueeze(0) + if use_cuda: + speaker_id.cuda() # synthesize voice decoder_output, postnet_output, alignments, stop_tokens = run_model( model, inputs, CONFIG, truncated, style_mel)
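
Notes on how the pieces fit together (illustrative sketches, not part of the patch):

Every formatter in datasets/preprocess.py now returns three-element items, [text, wav_file, speaker_name], which is what MyDataset.load_data() unpacks and what collate_fn forwards as a new speaker_name field between text_lengths and the spectrogram tensors. Single-speaker sets use a fixed name, common_voice takes the name from its metadata column, and mailabs recovers it from the by_book/<gender>/<speaker>/ directory layout with a regex whose named group must be called speaker_name, since the later .group("speaker_name") call depends on it. A sketch of a hypothetical custom formatter plus the M-AI-Labs lookup on an example path:

```python
import os
import re

def my_multispeaker_set(root_path, meta_file):
    """Hypothetical formatter: one 'speaker|wav_name|text' entry per line."""
    items = []
    with open(os.path.join(root_path, meta_file), 'r') as ttf:
        for line in ttf:
            speaker_name, wav_name, text = line.strip().split('|')
            wav_file = os.path.join(root_path, 'wavs', wav_name + '.wav')
            items.append([text, wav_file, speaker_name])  # new three-column item format
    return items

# M-AI-Labs: the speaker comes from the folder structure via a named group.
speaker_regex = re.compile("by_book/(male|female|mix)/(?P<speaker_name>[^/]+)/")
csv_file = "by_book/female/judy_bieber/ozma_of_oz/metadata.csv"  # example path
print(speaker_regex.search(csv_file).group("speaker_name"))      # -> judy_bieber
```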
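
The model-side change is the same in Tacotron, Tacotron2 and TacotronGST: a learned nn.Embedding over speakers is looked up once per utterance, broadcast along the time axis and added to the encoder outputs before decoding (in TacotronGST it is summed together with the style-token output). Because the vector is added rather than concatenated, its size must match the encoder width, which is why the Tacotron variants use 256 and Tacotron2 uses 512. A self-contained shape sketch with made-up dimensions:

```python
import torch
import torch.nn as nn

num_speakers, embed_dim = 4, 256   # hypothetical sizes, for illustration only
batch_size, max_time = 2, 7

speaker_embedding = nn.Embedding(num_speakers, embed_dim)
speaker_embedding.weight.data.normal_(0, 0.3)                    # same init as the patch

encoder_outputs = torch.randn(batch_size, max_time, embed_dim)   # stand-in for the encoder
speaker_ids = torch.LongTensor([0, 3])                           # one id per batch entry

# [B, D] -> [B, 1, D] -> [B, T, D], then added onto every encoder timestep
speaker_embeddings = speaker_embedding(speaker_ids).unsqueeze(1)
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
                                               encoder_outputs.size(1), -1)
encoder_outputs = encoder_outputs + speaker_embeddings
print(encoder_outputs.shape)  # torch.Size([2, 7, 256])
```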
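
On the training side, train() converts the per-batch speaker names coming out of the data loader into integer ids, growing the name-to-id mapping the first time a speaker is seen and persisting it with the new utils/speakers.py helpers; evaluate() reloads the same mapping, and copy_speaker_mapping() carries speakers.json over when a run is restored from a checkpoint. Condensed, the bookkeeping amounts to the following (hypothetical run directory and batch):

```python
import os
import torch
from utils.speakers import load_speaker_mapping, save_speaker_mapping

out_path = "/tmp/tts_run"                            # hypothetical output directory
os.makedirs(out_path, exist_ok=True)

speaker_mapping = load_speaker_mapping(out_path)     # {} if speakers.json does not exist yet
speaker_names = ["ljspeech", "mozilla", "ljspeech"]  # what collate_fn now yields per batch

speaker_ids = []
for speaker_name in speaker_names:
    if speaker_name not in speaker_mapping:          # first occurrence gets the next free id
        speaker_mapping[speaker_name] = len(speaker_mapping)
    speaker_ids.append(speaker_mapping[speaker_name])
speaker_ids = torch.LongTensor(speaker_ids)          # [0, 1, 0], fed to the model forward

save_speaker_mapping(out_path, speaker_mapping)      # written as speakers.json once per epoch
```

Since evaluate() only looks names up in the stored mapping, every speaker in the validation split needs to have been seen during training.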
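
Finally, setup_model() forwards the new c.num_speakers config value into the constructors, and every inference entry point (the models, utils/synthesis.py, the Benchmark notebook and the test-sentence loop in train.py) now takes a speaker id next to the text. A minimal sketch of driving the patched Tacotron directly, assuming the remaining constructor arguments keep their defaults and using random, untrained weights:

```python
import torch
from models.tacotron import Tacotron

# num_speakers normally comes from config.json (c.num_speakers); sizes here are illustrative
model = Tacotron(num_chars=130, num_speakers=2, r=5)
model.eval()

characters = torch.randint(1, 130, (1, 50))  # [B, T_in]; stand-in for text_to_seqvec() output
speaker_ids = torch.LongTensor([1])          # one id per batch entry, as kept in speakers.json
with torch.no_grad():
    mel_outputs, linear_outputs, alignments, stop_tokens = \
        model.inference(characters, speaker_ids)
print(mel_outputs.shape)                     # [1, T_out, 80]
```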