diff --git a/.gitignore b/.gitignore
index 3d5cc11a..2dcbddbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -114,4 +114,5 @@ venv.bak/
 
 # pytorch models
 *.pth.tar
+result/
 
diff --git a/datasets/.LJSpeech.py.swp b/datasets/.LJSpeech.py.swp
index 68d81691..b80a29a9 100644
Binary files a/datasets/.LJSpeech.py.swp and b/datasets/.LJSpeech.py.swp differ
diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py
index 5fdb850a..7f9aca36 100644
--- a/datasets/LJSpeech.py
+++ b/datasets/LJSpeech.py
@@ -5,7 +5,7 @@ import collections
 from torch.utils.data import Dataset
 
 import train_config as c
-from Tacotron.text import text_to_sequence
+from Tacotron.utils.text import text_to_sequence
 from Tacotron.utils.audio import *
 from Tacotron.utils.data import prepare_data, pad_data, pad_per_step
 
@@ -53,13 +53,17 @@ class LJSpeechDataset(Dataset):
 
             magnitude = np.array([spectrogram(w) for w in wav])
             mel = np.array([melspectrogram(w) for w in wav])
-            timesteps = mel.shape[-1]
+            timesteps = mel.shape[2]
 
             # PAD with zeros that can be divided by outputs per step
             if timesteps % self.outputs_per_step != 0:
                 magnitude = pad_per_step(magnitude, self.outputs_per_step)
                 mel = pad_per_step(mel, self.outputs_per_step)
 
+            # reshape jombo
+            magnitude = magnitude.transpose(0, 2, 1)
+            mel = mel.transpose(0, 2, 1)
+
             return text, magnitude, mel
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
diff --git a/network.py b/network.py
deleted file mode 100644
index 51149e42..00000000
--- a/network.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import random
-from module import *
-from text.symbols import symbols
-
-
-class Encoder(nn.Module):
-    """
-    Encoder
-    """
-
-    def __init__(self, embedding_size, hidden_size):
-        """
-
-        :param embedding_size: dimension of embedding
-        """
-        super(Encoder, self).__init__()
-        self.embedding_size = embedding_size
-        self.embed = nn.Embedding(len(symbols), embedding_size)
-        self.prenet = Prenet(embedding_size, hidden_size * 2, hidden_size)
-        self.cbhg = CBHG(hidden_size)
-
-    def forward(self, input_):
-
-        input_ = torch.transpose(self.embed(input_), 1, 2)
-        prenet = self.prenet.forward(input_)
-        memory = self.cbhg.forward(prenet)
-
-        return memory
-
-
-class MelDecoder(nn.Module):
-    """
-    Decoder
-    """
-
-    def __init__(self, num_mels, hidden_size, dec_out_per_step,
-                 teacher_forcing_ratio):
-
-        super(MelDecoder, self).__init__()
-        self.prenet = Prenet(num_mels, hidden_size * 2, hidden_size)
-        self.attn_decoder = AttentionDecoder(hidden_size * 2, num_mels,
-                                             dec_out_per_step)
-        self.dec_out_per_step = dec_out_per_step
-        self.teacher_forcing_ratio = teacher_forcing_ratio
-
-    def forward(self, decoder_input, memory):
-
-        # Initialize hidden state of GRUcells
-        attn_hidden, gru1_hidden, gru2_hidden = self.attn_decoder.inithidden(
-            decoder_input.size()[0])
-        outputs = list()
-
-        # Training phase
-        if self.training:
-            # Prenet
-            dec_input = self.prenet.forward(decoder_input)
-            timesteps = dec_input.size()[2] // self.dec_out_per_step
-
-            # [GO] Frame
-            prev_output = dec_input[:, :, 0]
-
-            for i in range(timesteps):
-                prev_output, attn_hidden, gru1_hidden, gru2_hidden = self.attn_decoder.forward(prev_output, memory,
-                                                                                               attn_hidden=attn_hidden,
-                                                                                               gru1_hidden=gru1_hidden,
-                                                                                               gru2_hidden=gru2_hidden)
-
-                outputs.append(prev_output)
-
-                if random.random() < self.teacher_forcing_ratio:
-                    # Get spectrum at rth position
-                    prev_output = dec_input[:, :, i * self.dec_out_per_step]
-                else:
-                    # Get last output
-                    prev_output = prev_output[:, :, -1]
-
-            # Concatenate all mel spectrogram
-            outputs = torch.cat(outputs, 2)
-
-        else:
-            # [GO] Frame
-            prev_output = decoder_input
-
-            for i in range(max_iters):
-                prev_output = self.prenet.forward(prev_output)
-                prev_output = prev_output[:, :, 0]
-                prev_output, attn_hidden, gru1_hidden, gru2_hidden = self.attn_decoder.forward(prev_output, memory,
-                                                                                               attn_hidden=attn_hidden,
-                                                                                               gru1_hidden=gru1_hidden,
-                                                                                               gru2_hidden=gru2_hidden)
-                outputs.append(prev_output)
-                prev_output = prev_output[:, :, -1].unsqueeze(2)
-
-            outputs = torch.cat(outputs, 2)
-
-        return outputs
-
-
-class PostProcessingNet(nn.Module):
-    """
-    Post-processing Network
-    """
-
-    def __init__(self, num_mels, num_freq, hidden_size):
-        super(PostProcessingNet, self).__init__()
-        self.postcbhg = CBHG(hidden_size,
-                             K=8,
-                             projection_size=num_mels,
-                             is_post=True)
-        self.linear = SeqLinear(hidden_size * 2,
-                                num_freq)
-
-    def forward(self, input_):
-        out = self.postcbhg.forward(input_)
-        out = self.linear.forward(torch.transpose(out, 1, 2))
-
-        return out
-
-
-class Tacotron(nn.Module):
-    """
-    End-to-end Tacotron Network
-    """
-
-    def __init__(self, embedding_size, hidden_size, num_mels, num_freq,
-                 dec_out_per_step, teacher_forcing_ratio):
-        super(Tacotron, self).__init__()
-        self.encoder = Encoder(embedding_size, hidden_size)
-        self.decoder1 = MelDecoder(num_mels, hidden_size, dec_out_per_step,
-                                   teacher_forcing_ratio)
-        self.decoder2 = PostProcessingNet(num_mels, num_freq, hidden_size)
-
-    def forward(self, characters, mel_input):
-        memory = self.encoder.forward(characters)
-        mel_output = self.decoder1.forward(mel_input, memory)
-        linear_output = self.decoder2.forward(mel_output)
-
-        return mel_output, linear_output
diff --git a/text/__init__.py b/text/__init__.py
deleted file mode 100644
index 3c06bbbd..00000000
--- a/text/__init__.py
+++ /dev/null
@@ -1,78 +0,0 @@
-#-*- coding: utf-8 -*-
-
-import re
-from Tacotron.text import cleaners
-from Tacotron.text.symbols import symbols
-
-
-# Mappings from symbol to numeric ID and vice versa:
-_symbol_to_id = {s: i for i, s in enumerate(symbols)}
-_id_to_symbol = {i: s for i, s in enumerate(symbols)}
-
-# Regular expression matching text enclosed in curly braces:
-_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
-
-
-def text_to_sequence(text, cleaner_names):
-    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
-
-      The text can optionally have ARPAbet sequences enclosed in curly braces embedded
-      in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
-
-      Args:
-        text: string to convert to a sequence
-        cleaner_names: names of the cleaner functions to run the text through
-
-      Returns:
-        List of integers corresponding to the symbols in the text
-    '''
-    sequence = []
-
-    # Check for curly braces and treat their contents as ARPAbet:
-    while len(text):
-        m = _curly_re.match(text)
-        if not m:
-            sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
-            break
-        sequence += _symbols_to_sequence(
-            _clean_text(m.group(1), cleaner_names))
-        sequence += _arpabet_to_sequence(m.group(2))
-        text = m.group(3)
-
-    # Append EOS token
-    sequence.append(_symbol_to_id['~'])
-    return sequence
-
-
-def sequence_to_text(sequence):
-    '''Converts a sequence of IDs back to a string'''
-    result = ''
-    for symbol_id in sequence:
-        if symbol_id in _id_to_symbol:
-            s = _id_to_symbol[symbol_id]
-            # Enclose ARPAbet back in curly braces:
-            if len(s) > 1 and s[0] == '@':
-                s = '{%s}' % s[1:]
-            result += s
-    return result.replace('}{', ' ')
-
-
-def _clean_text(text, cleaner_names):
-    for name in cleaner_names:
-        cleaner = getattr(cleaners, name)
-        if not cleaner:
-            raise Exception('Unknown cleaner: %s' % name)
-        text = cleaner(text)
-    return text
-
-
-def _symbols_to_sequence(symbols):
-    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
-
-
-def _arpabet_to_sequence(text):
-    return _symbols_to_sequence(['@' + s for s in text.split()])
-
-
-def _should_keep_symbol(s):
-    return s in _symbol_to_id and s is not '_' and s is not '~'
diff --git a/text/cleaners.py b/text/cleaners.py
deleted file mode 100644
index fe0a46a2..00000000
--- a/text/cleaners.py
+++ /dev/null
@@ -1,91 +0,0 @@
-#-*- coding: utf-8 -*-
-
-
-'''
-Cleaners are transformations that run over the input text at both training and eval time.
-
-Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
-hyperparameter. Some cleaners are English-specific. You'll typically want to use:
-  1. "english_cleaners" for English text
-  2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
-     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
-  3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
-     the symbols in symbols.py to match your data).
-'''
-
-import re
-from unidecode import unidecode
-from .numbers import normalize_numbers
-
-
-# Regular expression matching whitespace:
-_whitespace_re = re.compile(r'\s+')
-
-# List of (regular expression, replacement) pairs for abbreviations:
-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-    ('mrs', 'misess'),
-    ('mr', 'mister'),
-    ('dr', 'doctor'),
-    ('st', 'saint'),
-    ('co', 'company'),
-    ('jr', 'junior'),
-    ('maj', 'major'),
-    ('gen', 'general'),
-    ('drs', 'doctors'),
-    ('rev', 'reverend'),
-    ('lt', 'lieutenant'),
-    ('hon', 'honorable'),
-    ('sgt', 'sergeant'),
-    ('capt', 'captain'),
-    ('esq', 'esquire'),
-    ('ltd', 'limited'),
-    ('col', 'colonel'),
-    ('ft', 'fort'),
-]]
-
-
-def expand_abbreviations(text):
-    for regex, replacement in _abbreviations:
-        text = re.sub(regex, replacement, text)
-    return text
-
-
-def expand_numbers(text):
-    return normalize_numbers(text)
-
-
-def lowercase(text):
-    return text.lower()
-
-
-def collapse_whitespace(text):
-    return re.sub(_whitespace_re, ' ', text)
-
-
-def convert_to_ascii(text):
-    return unidecode(text)
-
-
-def basic_cleaners(text):
-    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
-    text = lowercase(text)
-    text = collapse_whitespace(text)
-    return text
-
-
-def transliteration_cleaners(text):
-    '''Pipeline for non-English text that transliterates to ASCII.'''
-    text = convert_to_ascii(text)
-    text = lowercase(text)
-    text = collapse_whitespace(text)
-    return text
-
-
-def english_cleaners(text):
-    '''Pipeline for English text, including number and abbreviation expansion.'''
-    text = convert_to_ascii(text)
-    text = lowercase(text)
-    text = expand_numbers(text)
-    text = expand_abbreviations(text)
-    text = collapse_whitespace(text)
-    return text
diff --git a/text/cmudict.py b/text/cmudict.py
deleted file mode 100644
index 6673546b..00000000
--- a/text/cmudict.py
+++ /dev/null
@@ -1,65 +0,0 @@
-#-*- coding: utf-8 -*-
-
-
-import re
-
-
-valid_symbols = [
-    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
-    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
-    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
-    'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
-    'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
-    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
-    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
-]
-
-_valid_symbol_set = set(valid_symbols)
-
-
-class CMUDict:
-    '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
-
-    def __init__(self, file_or_path, keep_ambiguous=True):
-        if isinstance(file_or_path, str):
-            with open(file_or_path, encoding='latin-1') as f:
-                entries = _parse_cmudict(f)
-        else:
-            entries = _parse_cmudict(file_or_path)
-        if not keep_ambiguous:
-            entries = {word: pron for word,
-                       pron in entries.items() if len(pron) == 1}
-        self._entries = entries
-
-    def __len__(self):
-        return len(self._entries)
-
-    def lookup(self, word):
-        '''Returns list of ARPAbet pronunciations of the given word.'''
-        return self._entries.get(word.upper())
-
-
-_alt_re = re.compile(r'\([0-9]+\)')
-
-
-def _parse_cmudict(file):
-    cmudict = {}
-    for line in file:
-        if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
-            parts = line.split('  ')
-            word = re.sub(_alt_re, '', parts[0])
-            pronunciation = _get_pronunciation(parts[1])
-            if pronunciation:
-                if word in cmudict:
-                    cmudict[word].append(pronunciation)
-                else:
-                    cmudict[word] = [pronunciation]
-    return cmudict
-
-
-def _get_pronunciation(s):
-    parts = s.strip().split(' ')
-    for part in parts:
-        if part not in _valid_symbol_set:
-            return None
-    return ' '.join(parts)
diff --git a/text/numbers.py b/text/numbers.py
deleted file mode 100644
index 4ce2d389..00000000
--- a/text/numbers.py
+++ /dev/null
@@ -1,71 +0,0 @@
-#-*- coding: utf-8 -*-
-
-import inflect
-import re
-
-
-_inflect = inflect.engine()
-_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
-_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
-_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
-_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
-_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
-_number_re = re.compile(r'[0-9]+')
-
-
-def _remove_commas(m):
-    return m.group(1).replace(',', '')
-
-
-def _expand_decimal_point(m):
-    return m.group(1).replace('.', ' point ')
-
-
-def _expand_dollars(m):
-    match = m.group(1)
-    parts = match.split('.')
-    if len(parts) > 2:
-        return match + ' dollars'  # Unexpected format
-    dollars = int(parts[0]) if parts[0] else 0
-    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
-    if dollars and cents:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
-    elif dollars:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        return '%s %s' % (dollars, dollar_unit)
-    elif cents:
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s' % (cents, cent_unit)
-    else:
-        return 'zero dollars'
-
-
-def _expand_ordinal(m):
-    return _inflect.number_to_words(m.group(0))
-
-
-def _expand_number(m):
-    num = int(m.group(0))
-    if num > 1000 and num < 3000:
-        if num == 2000:
-            return 'two thousand'
-        elif num > 2000 and num < 2010:
-            return 'two thousand ' + _inflect.number_to_words(num % 100)
-        elif num % 100 == 0:
-            return _inflect.number_to_words(num // 100) + ' hundred'
-        else:
-            return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
-    else:
-        return _inflect.number_to_words(num, andword='')
-
-
-def normalize_numbers(text):
-    text = re.sub(_comma_number_re, _remove_commas, text)
-    text = re.sub(_pounds_re, r'\1 pounds', text)
-    text = re.sub(_dollars_re, _expand_dollars, text)
-    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
-    text = re.sub(_ordinal_re, _expand_ordinal, text)
-    text = re.sub(_number_re, _expand_number, text)
-    return text
diff --git a/text/symbols.py b/text/symbols.py
deleted file mode 100644
index e9a1a249..00000000
--- a/text/symbols.py
+++ /dev/null
@@ -1,24 +0,0 @@
-#-*- coding: utf-8 -*-
-
-
-'''
-Defines the set of symbols used in text input to the model.
-
-The default is a set of ASCII characters that works well for English or text that has been run
-through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
-'''
-from Tacotron.text import cmudict
-
-_pad = '_'
-_eos = '~'
-_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
-
-# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
-_arpabet = ['@' + s for s in cmudict.valid_symbols]
-
-# Export all symbols:
-symbols = [_pad, _eos] + list(_characters) + _arpabet
-
-
-if __name__ == '__main__':
-    print(symbols)
diff --git a/train.py b/train.py
index 484beafc..f575bc48 100644
--- a/train.py
+++ b/train.py
@@ -8,14 +8,15 @@ import numpy as np
 
 import torch.nn as nn
 from torch import optim
+from torch.autograd import Variable
 from torch.utils.data import DataLoader
 
-from network import *
 import train_config as c
 from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint)
 from utils.model import get_param_size
 from datasets.LJSpeech import LJSpeechDataset
+from models.tacotron import Tacotron
 
 
 use_cuda = torch.cuda.is_available()
@@ -40,8 +41,7 @@ def main(args):
                      c.hidden_size,
                      c.num_mels,
                      c.num_freq,
-                     c.dec_out_per_step,
-                     c.teacher_forcing_ratio)
+                     c.dec_out_per_step)
 
     if use_cuda:
         model = nn.DataParallel(model.cuda())
@@ -73,10 +73,13 @@ def main(args):
 
         dataloader = DataLoader(dataset, batch_size=args.batch_size,
                                 shuffle=True, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=8)
+                                drop_last=True, num_workers=32)
 
         progbar = Progbar(len(dataset) / args.batch_size)
 
         for i, data in enumerate(dataloader):
+            text_input = data[0]
+            magnitude_input = data[1]
+            mel_input = data[2]
 
             current_step = i + args.restore_step + epoch * len(dataloader) + 1
@@ -84,34 +87,37 @@ def main(args):
 
             try:
                 mel_input = np.concatenate((np.zeros(
-                    [args.batch_size, c.num_mels, 1], dtype=np.float32), data[2][:, :, 1:]), axis=2)
+                    [args.batch_size, 1, c.num_mels], dtype=np.float32),
+                    mel_input[:, 1:, :]), axis=1)
             except:
                 raise TypeError("not same dimension")
 
             if use_cuda:
-                characters = Variable(torch.from_numpy(data[0]).type(
+                text_input_var = Variable(torch.from_numpy(text_input).type(
                     torch.cuda.LongTensor), requires_grad=False).cuda()
-                mel_input = Variable(torch.from_numpy(mel_input).type(
+                mel_input_var = Variable(torch.from_numpy(mel_input).type(
                     torch.cuda.FloatTensor), requires_grad=False).cuda()
-                mel_spectrogram = Variable(torch.from_numpy(data[2]).type(
-                    torch.cuda.FloatTensor), requires_grad=False).cuda()
-                linear_spectrogram = Variable(torch.from_numpy(data[1]).type(
+                mel_spec_var = Variable(torch.from_numpy(mel_input).type(
                     torch.cuda.FloatTensor), requires_grad=False).cuda()
+                linear_spec_var = Variable(torch.from_numpy(magnitude_input)
+                    .type(torch.cuda.FloatTensor), requires_grad=False).cuda()
 
             else:
-                characters = Variable(torch.from_numpy(data[0]).type(
+                text_input_var = Variable(torch.from_numpy(text_input).type(
                     torch.LongTensor), requires_grad=False)
-                mel_input = Variable(torch.from_numpy(mel_input).type(
+                mel_input_var = Variable(torch.from_numpy(mel_input).type(
                     torch.FloatTensor), requires_grad=False)
-                mel_spectrogram = Variable(torch.from_numpy(
-                    data[2]).type(torch.FloatTensor), requires_grad=False)
-                linear_spectrogram = Variable(torch.from_numpy(
-                    data[1]).type(torch.FloatTensor), requires_grad=False)
+                mel_spec_var = Variable(torch.from_numpy(
+                    mel_input).type(torch.FloatTensor), requires_grad=False)
+                linear_spec_var = Variable(torch.from_numpy(
+                    magnitude_input).type(torch.FloatTensor),
+                    requires_grad=False)
 
-            mel_output, linear_output = model.forward(characters, mel_input)
+            mel_output, linear_output, alignments =\
+                model.forward(text_input_var, mel_input_var)
 
-            mel_loss = criterion(mel_output, mel_spectrogram)
-            linear_loss = torch.abs(linear_output - linear_spectrogram)
+            mel_loss = criterion(mel_output, mel_spec_var)
+            linear_loss = torch.abs(linear_output - linear_spec_var)
             linear_loss = 0.5 * \
                 torch.mean(linear_loss) + 0.5 * \
                 torch.mean(linear_loss[:, :n_priority_freq, :])
@@ -168,9 +174,9 @@ def adjust_learning_rate(optimizer, step):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--restore_step', type=int,
-                        help='Global step to restore checkpoint', default=0)
+                        help='Global step to restore checkpoint', default=128)
     parser.add_argument('--batch_size', type=int,
-                        help='Batch size', default=32)
+                        help='Batch size', default=128)
     parser.add_argument('--config', type=str,
                         help='path to config file for training',)
     args = parser.parse_args()
diff --git a/train_config.py b/train_config.py
index 5270d795..96a4a59d 100644
--- a/train_config.py
+++ b/train_config.py
@@ -10,20 +10,24 @@ ref_level_db = 20
 hidden_size = 128
 embedding_size = 256
 
+# training
+epochs = 10000
+lr = 0.001
+decay_step = [500000, 1000000, 2000000]
+batch_size = 128
 max_iters = 200
 griffin_lim_iters = 60
 power = 1.5
 dec_out_per_step = 5
-teacher_forcing_ratio = 1.0
+#teacher_forcing_ratio = 1.0
 
-epochs = 10000
-lr = 0.001
-decay_step = [500000, 1000000, 2000000]
+# outputing
 log_step = 100
 save_step = 2000
 
+# text processing
 cleaners = 'english_cleaners'
 
+# data settings
 data_path = '/data/shared/KeithIto/LJSpeech-1.0/'
 output_path = './result'
-checkpoint_path = './model_new'
diff --git a/utils/.data.py.swo b/utils/.data.py.swo
deleted file mode 100644
index 87e6fdba..00000000
Binary files a/utils/.data.py.swo and /dev/null differ