diff --git a/datasets/Kusal.py b/datasets/Kusal.py index f0458b84..b431aee3 100644 --- a/datasets/Kusal.py +++ b/datasets/Kusal.py @@ -8,21 +8,28 @@ import torch from torch.utils.data import Dataset from utils.text import text_to_sequence -from utils.data import (prepare_data, pad_per_step, - prepare_tensor, prepare_stop_target) +from utils.data import (prepare_data, pad_per_step, prepare_tensor, + prepare_stop_target) class MyDataset(Dataset): - - def __init__(self, root_dir, csv_file, outputs_per_step, - text_cleaner, ap, min_seq_len=0): + def __init__(self, + root_dir, + csv_file, + outputs_per_step, + text_cleaner, + ap, + min_seq_len=0): self.root_dir = root_dir self.wav_dir = os.path.join(root_dir, 'wav') self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav')) self._create_file_dict() self.csv_dir = os.path.join(root_dir, csv_file) with open(self.csv_dir, "r", encoding="utf8") as f: - self.frames = [line.split('\t') for line in f if line.split('\t')[0] in self.wav_files_dict.keys()] + self.frames = [ + line.split('\t') for line in f + if line.split('\t')[0] in self.wav_files_dict.keys() + ] self.outputs_per_step = outputs_per_step self.sample_rate = ap.sample_rate self.cleaners = text_cleaner @@ -43,10 +50,8 @@ class MyDataset(Dataset): print(" !! Cannot read file : {}".format(filename)) def _trim_silence(self, wav): - return librosa.effects.trim( - wav, top_db=40, - frame_length=1024, - hop_length=256)[0] + return librosa.effects.trim( + wav, top_db=40, frame_length=1024, hop_length=256)[0] def _create_file_dict(self): self.wav_files_dict = {} @@ -87,11 +92,10 @@ class MyDataset(Dataset): sidx = self.frames[idx][0] sidx_files = self.wav_files_dict[sidx] file_name = random.choice(sidx_files) - wav_name = os.path.join(self.wav_dir, - file_name) + wav_name = os.path.join(self.wav_dir, file_name) text = self.frames[idx][2] - text = np.asarray(text_to_sequence( - text, [self.cleaners]), dtype=np.int32) + text = np.asarray( + text_to_sequence(text, [self.cleaners]), dtype=np.int32) wav = np.asarray(self.load_wav(wav_name), dtype=np.float32) sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} return sample @@ -121,12 +125,13 @@ class MyDataset(Dataset): mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame # compute 'stop token' targets - stop_targets = [np.array([0.]*(mel_len-1)) - for mel_len in mel_lengths] + stop_targets = [ + np.array([0.] 
* (mel_len - 1)) for mel_len in mel_lengths + ] # PAD stop targets - stop_targets = prepare_stop_target( - stop_targets, self.outputs_per_step) + stop_targets = prepare_stop_target(stop_targets, + self.outputs_per_step) # PAD sequences with largest length of the batch text = prepare_data(text).astype(np.int32) @@ -150,8 +155,8 @@ class MyDataset(Dataset): mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) - return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0] + return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[ + 0] raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ - found {}" - .format(type(batch[0])))) + found {}".format(type(batch[0])))) diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index 377855dc..d8eb66b4 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -6,14 +6,18 @@ import torch from torch.utils.data import Dataset from utils.text import text_to_sequence -from utils.data import (prepare_data, pad_per_step, - prepare_tensor, prepare_stop_target) +from utils.data import (prepare_data, pad_per_step, prepare_tensor, + prepare_stop_target) class MyDataset(Dataset): - - def __init__(self, root_dir, csv_file, outputs_per_step, - text_cleaner, ap, min_seq_len=0): + def __init__(self, + root_dir, + csv_file, + outputs_per_step, + text_cleaner, + ap, + min_seq_len=0): self.root_dir = root_dir self.wav_dir = os.path.join(root_dir, 'wavs') self.csv_dir = os.path.join(root_dir, csv_file) @@ -60,11 +64,10 @@ class MyDataset(Dataset): return len(self.frames) def __getitem__(self, idx): - wav_name = os.path.join(self.wav_dir, - self.frames[idx][0]) + '.wav' + wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav' text = self.frames[idx][1] - text = np.asarray(text_to_sequence( - text, [self.cleaners]), dtype=np.int32) + text = np.asarray( + text_to_sequence(text, [self.cleaners]), dtype=np.int32) wav = np.asarray(self.load_wav(wav_name), dtype=np.float32) sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} return sample @@ -94,12 +97,13 @@ class MyDataset(Dataset): mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame # compute 'stop token' targets - stop_targets = [np.array([0.]*(mel_len-1)) - for mel_len in mel_lengths] + stop_targets = [ + np.array([0.] 
* (mel_len - 1)) for mel_len in mel_lengths + ] # PAD stop targets - stop_targets = prepare_stop_target( - stop_targets, self.outputs_per_step) + stop_targets = prepare_stop_target(stop_targets, + self.outputs_per_step) # PAD sequences with largest length of the batch text = prepare_data(text).astype(np.int32) @@ -123,8 +127,8 @@ class MyDataset(Dataset): mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) - return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0] + return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[ + 0] raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ - found {}" - .format(type(batch[0])))) + found {}".format(type(batch[0])))) diff --git a/datasets/LJSpeechCached.py b/datasets/LJSpeechCached.py index c0018493..09061cc3 100644 --- a/datasets/LJSpeechCached.py +++ b/datasets/LJSpeechCached.py @@ -6,14 +6,18 @@ import torch from torch.utils.data import Dataset from utils.text import text_to_sequence -from utils.data import (prepare_data, pad_per_step, - prepare_tensor, prepare_stop_target) +from utils.data import (prepare_data, pad_per_step, prepare_tensor, + prepare_stop_target) class MyDataset(Dataset): - - def __init__(self, root_dir, csv_file, outputs_per_step, - text_cleaner, ap, min_seq_len=0): + def __init__(self, + root_dir, + csv_file, + outputs_per_step, + text_cleaner, + ap, + min_seq_len=0): self.root_dir = root_dir self.wav_dir = os.path.join(root_dir, 'wavs') self.feat_dir = os.path.join(root_dir, 'loader_data') @@ -35,7 +39,7 @@ class MyDataset(Dataset): return audio except RuntimeError as e: print(" !! Cannot read file : {}".format(filename)) - + def load_np(self, filename): data = np.load(filename).astype('float32') return data @@ -66,20 +70,24 @@ class MyDataset(Dataset): def __getitem__(self, idx): if self.items[idx] is None: - wav_name = os.path.join(self.wav_dir, - self.frames[idx][0]) + '.wav' + wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav' mel_name = os.path.join(self.feat_dir, self.frames[idx][0]) + '.mel.npy' linear_name = os.path.join(self.feat_dir, self.frames[idx][0]) + '.linear.npy' text = self.frames[idx][1] - text = np.asarray(text_to_sequence( - text, [self.cleaners]), dtype=np.int32) + text = np.asarray( + text_to_sequence(text, [self.cleaners]), dtype=np.int32) wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32) mel = self.load_np(mel_name) linear = self.load_np(linear_name) - sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0], - 'mel':mel, 'linear': linear} + sample = { + 'text': text, + 'wav': wav, + 'item_idx': self.frames[idx][0], + 'mel': mel, + 'linear': linear + } self.items[idx] = sample else: sample = self.items[idx] @@ -109,12 +117,13 @@ class MyDataset(Dataset): mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame # compute 'stop token' targets - stop_targets = [np.array([0.]*(mel_len-1)) - for mel_len in mel_lengths] + stop_targets = [ + np.array([0.] 
* (mel_len - 1)) for mel_len in mel_lengths + ] # PAD stop targets - stop_targets = prepare_stop_target( - stop_targets, self.outputs_per_step) + stop_targets = prepare_stop_target(stop_targets, + self.outputs_per_step) # PAD sequences with largest length of the batch text = prepare_data(text).astype(np.int32) @@ -138,8 +147,8 @@ class MyDataset(Dataset): mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) - return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0] + return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[ + 0] raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ - found {}" - .format(type(batch[0])))) + found {}".format(type(batch[0])))) diff --git a/datasets/TWEB.py b/datasets/TWEB.py index eeae551a..50cdb448 100644 --- a/datasets/TWEB.py +++ b/datasets/TWEB.py @@ -7,15 +7,25 @@ from torch.utils.data import Dataset from TTS.utils.text import text_to_sequence from TTS.utils.audio import AudioProcessor -from TTS.utils.data import (prepare_data, pad_per_step, - prepare_tensor, prepare_stop_target) +from TTS.utils.data import (prepare_data, pad_per_step, prepare_tensor, + prepare_stop_target) class TWEBDataset(Dataset): - - def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate, - text_cleaner, num_mels, min_level_db, frame_shift_ms, - frame_length_ms, preemphasis, ref_level_db, num_freq, power, + def __init__(self, + csv_file, + root_dir, + outputs_per_step, + sample_rate, + text_cleaner, + num_mels, + min_level_db, + frame_shift_ms, + frame_length_ms, + preemphasis, + ref_level_db, + num_freq, + power, min_seq_len=0): with open(csv_file, "r") as f: @@ -25,8 +35,9 @@ class TWEBDataset(Dataset): self.sample_rate = sample_rate self.cleaners = text_cleaner self.min_seq_len = min_seq_len - self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms, - frame_length_ms, preemphasis, ref_level_db, num_freq, power) + self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, + frame_shift_ms, frame_length_ms, preemphasis, + ref_level_db, num_freq, power) print(" > Reading TWEB from - {}".format(root_dir)) print(" | > Number of instances : {}".format(len(self.frames))) self._sort_frames() @@ -63,11 +74,10 @@ class TWEBDataset(Dataset): return len(self.frames) def __getitem__(self, idx): - wav_name = os.path.join(self.root_dir, - self.frames[idx][0]) + '.wav' + wav_name = os.path.join(self.root_dir, self.frames[idx][0]) + '.wav' text = self.frames[idx][1] - text = np.asarray(text_to_sequence( - text, [self.cleaners]), dtype=np.int32) + text = np.asarray( + text_to_sequence(text, [self.cleaners]), dtype=np.int32) wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32) sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} return sample @@ -97,12 +107,13 @@ class TWEBDataset(Dataset): mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame # compute 'stop token' targets - stop_targets = [np.array([0.]*(mel_len-1)) - for mel_len in mel_lengths] + stop_targets = [ + np.array([0.] 
* (mel_len - 1)) for mel_len in mel_lengths + ] # PAD stop targets - stop_targets = prepare_stop_target( - stop_targets, self.outputs_per_step) + stop_targets = prepare_stop_target(stop_targets, + self.outputs_per_step) # PAD sequences with largest length of the batch text = prepare_data(text).astype(np.int32) @@ -126,8 +137,8 @@ class TWEBDataset(Dataset): mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) - return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0] + return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[ + 0] raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ - found {}" - .format(type(batch[0])))) + found {}".format(type(batch[0])))) diff --git a/debug_config.py b/debug_config.py index 7a9a94ab..51f08ce8 100644 --- a/debug_config.py +++ b/debug_config.py @@ -10,7 +10,6 @@ "hidden_size": 128, "embedding_size": 256, "text_cleaner": "english_cleaners", - "epochs": 200, "lr": 0.01, "lr_patience": 2, @@ -19,9 +18,7 @@ "griffinf_lim_iters": 60, "power": 1.5, "r": 5, - "num_loader_workers": 16, - "save_step": 1, "data_path": "/data/shared/KeithIto/LJSpeech-1.0", "output_path": "result", diff --git a/extract_feats.py b/extract_feats.py index 0a4f99c4..f075110c 100644 --- a/extract_feats.py +++ b/extract_feats.py @@ -13,19 +13,19 @@ from utils.generic_utils import load_config from multiprocessing import Pool - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--data_path', type=str, - help='Data folder.') - parser.add_argument('--out_path', type=str, - help='Output folder.') - parser.add_argument('--config', type=str, - help='conf.json file for run settings.') - parser.add_argument("--num_proc", type=int, default=8, - help="number of processes.") - parser.add_argument("--trim_silence", type=bool, default=False, - help="trim silence in the voice clip.") + parser.add_argument('--data_path', type=str, help='Data folder.') + parser.add_argument('--out_path', type=str, help='Output folder.') + parser.add_argument( + '--config', type=str, help='conf.json file for run settings.') + parser.add_argument( + "--num_proc", type=int, default=8, help="number of processes.") + parser.add_argument( + "--trim_silence", + type=bool, + default=False, + help="trim silence in the voice clip.") args = parser.parse_args() DATA_PATH = args.data_path OUT_PATH = args.out_path @@ -34,27 +34,26 @@ if __name__ == "__main__": print(" > Input path: ", DATA_PATH) print(" > Output path: ", OUT_PATH) - audio = importlib.import_module('utils.'+c.audio_processor) + audio = importlib.import_module('utils.' 
+ c.audio_processor) AudioProcessor = getattr(audio, 'AudioProcessor') - ap = AudioProcessor(sample_rate = CONFIG.sample_rate, - num_mels = CONFIG.num_mels, - min_level_db = CONFIG.min_level_db, - frame_shift_ms = CONFIG.frame_shift_ms, - frame_length_ms = CONFIG.frame_length_ms, - ref_level_db = CONFIG.ref_level_db, - num_freq = CONFIG.num_freq, - power = CONFIG.power, - preemphasis = CONFIG.preemphasis, - min_mel_freq = CONFIG.min_mel_freq, - max_mel_freq = CONFIG.max_mel_freq) + ap = AudioProcessor( + sample_rate=CONFIG.sample_rate, + num_mels=CONFIG.num_mels, + min_level_db=CONFIG.min_level_db, + frame_shift_ms=CONFIG.frame_shift_ms, + frame_length_ms=CONFIG.frame_length_ms, + ref_level_db=CONFIG.ref_level_db, + num_freq=CONFIG.num_freq, + power=CONFIG.power, + preemphasis=CONFIG.preemphasis, + min_mel_freq=CONFIG.min_mel_freq, + max_mel_freq=CONFIG.max_mel_freq) def trim_silence(self, wav): margin = int(CONFIG.sample_rate * 0.1) wav = wav[margin:-margin] return librosa.effects.trim( - wav, top_db=40, - frame_length=1024, - hop_length=256)[0] + wav, top_db=40, frame_length=1024, hop_length=256)[0] def extract_mel(file_path): # x, fs = sf.read(file_path) @@ -63,23 +62,25 @@ if __name__ == "__main__": x = trim_silence(x) mel = ap.melspectrogram(x.astype('float32')).astype('float32') linear = ap.spectrogram(x.astype('float32')).astype('float32') - file_name = os.path.basename(file_path).replace(".wav","") + file_name = os.path.basename(file_path).replace(".wav", "") mel_file = file_name + ".mel" linear_file = file_name + ".linear" np.save(os.path.join(OUT_PATH, mel_file), mel, allow_pickle=False) - np.save(os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False) + np.save( + os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False) mel_len = mel.shape[1] linear_len = linear.shape[1] wav_len = x.shape[0] print(" > " + file_path, flush=True) - return file_path, mel_file, linear_file, str(wav_len), str(mel_len), str(linear_len) + return file_path, mel_file, linear_file, str(wav_len), str( + mel_len), str(linear_len) glob_path = os.path.join(DATA_PATH, "*.wav") print(" > Reading wav: {}".format(glob_path)) file_names = glob.glob(glob_path, recursive=True) if __name__ == "__main__": - print(" > Number of files: %i"%(len(file_names))) + print(" > Number of files: %i" % (len(file_names))) if not os.path.exists(OUT_PATH): os.makedirs(OUT_PATH) print(" > A new folder created at {}".format(OUT_PATH)) @@ -88,7 +89,10 @@ if __name__ == "__main__": if args.num_proc > 1: print(" > Using {} processes.".format(args.num_proc)) with Pool(args.num_proc) as p: - r = list(tqdm.tqdm(p.imap(extract_mel, file_names), total=len(file_names))) + r = list( + tqdm.tqdm( + p.imap(extract_mel, file_names), + total=len(file_names))) # r = list(p.imap(extract_mel, file_names)) else: print(" > Using single process run.") @@ -100,5 +104,5 @@ if __name__ == "__main__": file = open(file_path, "w") for line in r: line = ", ".join(line) - file.write(line+'\n') + file.write(line + '\n') file.close() diff --git a/layers/attention.py b/layers/attention.py index 8b445da8..494d14dc 100644 --- a/layers/attention.py +++ b/layers/attention.py @@ -24,8 +24,8 @@ class BahdanauAttention(nn.Module): processed_query = self.query_layer(query) processed_annots = self.annot_layer(annots) # (batch, max_time, 1) - alignment = self.v(nn.functional.tanh( - processed_query + processed_annots)) + alignment = self.v( + nn.functional.tanh(processed_query + processed_annots)) # (batch, max_time) return alignment.squeeze(-1) @@ -33,15 
+33,24 @@ class BahdanauAttention(nn.Module): class LocationSensitiveAttention(nn.Module): """Location sensitive attention following https://arxiv.org/pdf/1506.07503.pdf""" - def __init__(self, annot_dim, query_dim, attn_dim, - kernel_size=7, filters=20): + + def __init__(self, + annot_dim, + query_dim, + attn_dim, + kernel_size=7, + filters=20): super(LocationSensitiveAttention, self).__init__() self.kernel_size = kernel_size self.filters = filters padding = int((kernel_size - 1) / 2) - self.loc_conv = nn.Conv1d(2, filters, - kernel_size=kernel_size, stride=1, - padding=padding, bias=False) + self.loc_conv = nn.Conv1d( + 2, + filters, + kernel_size=kernel_size, + stride=1, + padding=padding, + bias=False) self.loc_linear = nn.Linear(filters, attn_dim) self.query_layer = nn.Linear(query_dim, attn_dim, bias=True) self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True) @@ -62,8 +71,9 @@ class LocationSensitiveAttention(nn.Module): processed_loc = self.loc_linear(loc_conv) processed_query = self.query_layer(query) processed_annots = self.annot_layer(annot) - alignment = self.v(nn.functional.tanh( - processed_query + processed_annots + processed_loc)) + alignment = self.v( + nn.functional.tanh(processed_query + processed_annots + + processed_loc)) # (batch, max_time) return alignment.squeeze(-1) @@ -85,16 +95,23 @@ class AttentionRNNCell(nn.Module): self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim) # pick bahdanau or location sensitive attention if align_model == 'b': - self.alignment_model = BahdanauAttention(annot_dim, rnn_dim, out_dim) + self.alignment_model = BahdanauAttention(annot_dim, rnn_dim, + out_dim) if align_model == 'ls': - self.alignment_model = LocationSensitiveAttention(annot_dim, rnn_dim, out_dim) + self.alignment_model = LocationSensitiveAttention( + annot_dim, rnn_dim, out_dim) else: raise RuntimeError(" Wrong alignment model name: {}. Use\ - 'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model)) + 'b' (Bahdanau) or 'ls' (Location Sensitive)." 
+ .format(align_model)) - - def forward(self, memory, context, rnn_state, annots, - atten, annot_lens=None): + def forward(self, + memory, + context, + rnn_state, + annots, + atten, + annot_lens=None): """ Shapes: - memory: (batch, 1, dim) or (batch, dim) diff --git a/layers/custom_layers.py b/layers/custom_layers.py index e7f52d7c..e1fde912 100644 --- a/layers/custom_layers.py +++ b/layers/custom_layers.py @@ -2,7 +2,6 @@ import torch from torch import nn - # class StopProjection(nn.Module): # r""" Simple projection layer to predict the "stop token" diff --git a/layers/losses.py b/layers/losses.py index d7b21c38..acf4e789 100644 --- a/layers/losses.py +++ b/layers/losses.py @@ -5,7 +5,6 @@ from utils.generic_utils import sequence_mask class L1LossMasked(nn.Module): - def __init__(self): super(L1LossMasked, self).__init__() @@ -31,21 +30,20 @@ class L1LossMasked(nn.Module): # target_flat: (batch * max_len, dim) target_flat = target.view(-1, target.shape[-1]) # losses_flat: (batch * max_len, dim) - losses_flat = functional.l1_loss(input, target_flat, size_average=False, - reduce=False) + losses_flat = functional.l1_loss( + input, target_flat, size_average=False, reduce=False) # losses: (batch, max_len, dim) losses = losses_flat.view(*target.size()) # mask: (batch, max_len, 1) - mask = sequence_mask(sequence_length=length, - max_len=target.size(1)).unsqueeze(2) + mask = sequence_mask( + sequence_length=length, max_len=target.size(1)).unsqueeze(2) losses = losses * mask.float() loss = losses.sum() / (length.float().sum() * float(target.shape[2])) return loss class MSELossMasked(nn.Module): - def __init__(self): super(MSELossMasked, self).__init__() @@ -71,14 +69,14 @@ class MSELossMasked(nn.Module): # target_flat: (batch * max_len, dim) target_flat = target.view(-1, target.shape[-1]) # losses_flat: (batch * max_len, dim) - losses_flat = functional.mse_loss(input, target_flat, size_average=False, - reduce=False) + losses_flat = functional.mse_loss( + input, target_flat, size_average=False, reduce=False) # losses: (batch, max_len, dim) losses = losses_flat.view(*target.size()) # mask: (batch, max_len, 1) - mask = sequence_mask(sequence_length=length, - max_len=target.size(1)).unsqueeze(2) + mask = sequence_mask( + sequence_length=length, max_len=target.size(1)).unsqueeze(2) losses = losses * mask.float() loss = losses.sum() / (length.float().sum() * float(target.shape[2])) return loss diff --git a/layers/tacotron.py b/layers/tacotron.py index c5b1d2ff..df53e44c 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -3,6 +3,7 @@ import torch from torch import nn from .attention import AttentionRNNCell + class Prenet(nn.Module): r""" Prenet as explained at https://arxiv.org/abs/1703.10135. 
It creates as many layers as given by 'out_features' @@ -16,9 +17,10 @@ class Prenet(nn.Module): def __init__(self, in_features, out_features=[256, 128]): super(Prenet, self).__init__() in_features = [in_features] + out_features[:-1] - self.layers = nn.ModuleList( - [nn.Linear(in_size, out_size) - for (in_size, out_size) in zip(in_features, out_features)]) + self.layers = nn.ModuleList([ + nn.Linear(in_size, out_size) + for (in_size, out_size) in zip(in_features, out_features) + ]) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.5) @@ -46,12 +48,21 @@ class BatchNormConv1d(nn.Module): - output: batch x dims """ - def __init__(self, in_channels, out_channels, kernel_size, stride, padding, + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding, activation=None): super(BatchNormConv1d, self).__init__() - self.conv1d = nn.Conv1d(in_channels, out_channels, - kernel_size=kernel_size, - stride=stride, padding=padding, bias=False) + self.conv1d = nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False) # Following tensorflow's default parameters self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3) self.activation = activation @@ -96,16 +107,25 @@ class CBHG(nn.Module): - output: batch x time x dim*2 """ - def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4): + def __init__(self, + in_features, + K=16, + projections=[128, 128], + num_highways=4): super(CBHG, self).__init__() self.in_features = in_features self.relu = nn.ReLU() # list of conv1d bank with filter size k=1...K # TODO: try dilational layers instead - self.conv1d_banks = nn.ModuleList( - [BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1, - padding=k // 2, activation=self.relu) - for k in range(1, K + 1)]) + self.conv1d_banks = nn.ModuleList([ + BatchNormConv1d( + in_features, + in_features, + kernel_size=k, + stride=1, + padding=k // 2, + activation=self.relu) for k in range(1, K + 1) + ]) # max pooling of conv bank # TODO: try average pooling OR larger kernel size self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) @@ -114,9 +134,15 @@ class CBHG(nn.Module): activations += [None] # setup conv1d projection layers layer_set = [] - for (in_size, out_size, ac) in zip(out_features, projections, activations): - layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, - padding=1, activation=ac) + for (in_size, out_size, ac) in zip(out_features, projections, + activations): + layer = BatchNormConv1d( + in_size, + out_size, + kernel_size=3, + stride=1, + padding=1, + activation=ac) layer_set.append(layer) self.conv1d_projections = nn.ModuleList(layer_set) # setup Highway layers @@ -204,10 +230,14 @@ class Decoder(nn.Module): # memory -> |Prenet| -> processed_memory self.prenet = Prenet(memory_dim * r, out_features=[256, 128]) # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State - self.attention_rnn = AttentionRNNCell(out_dim=128, rnn_dim=256, annot_dim=in_features, - memory_dim=128, align_model='ls') + self.attention_rnn = AttentionRNNCell( + out_dim=128, + rnn_dim=256, + annot_dim=in_features, + memory_dim=128, + align_model='ls') # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input - self.project_to_decoder_in = nn.Linear(256+in_features, 256) + self.project_to_decoder_in = nn.Linear(256 + in_features, 256) # decoder_RNN_input -> |RNN| -> RNN_state self.decoder_rnns = nn.ModuleList( [nn.GRUCell(256, 256) for _ in 
range(2)]) @@ -241,17 +271,20 @@ class Decoder(nn.Module): # Grouping multiple frames if necessary if memory.size(-1) == self.memory_dim: memory = memory.view(B, memory.size(1) // self.r, -1) - " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1), - self.memory_dim, self.r) + " !! Dimension mismatch {} vs {} * {}".format( + memory.size(-1), self.memory_dim, self.r) T_decoder = memory.size(1) # go frame as zeros matrix initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_() # decoder states attention_rnn_hidden = inputs.data.new(B, 256).zero_() - decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_() - for _ in range(len(self.decoder_rnns))] + decoder_rnn_hiddens = [ + inputs.data.new(B, 256).zero_() + for _ in range(len(self.decoder_rnns)) + ] current_context_vec = inputs.data.new(B, self.in_features).zero_() - stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_() + stopnet_rnn_hidden = inputs.data.new(B, + self.r * self.memory_dim).zero_() # attention states attention = inputs.data.new(B, T).zero_() attention_cum = inputs.data.new(B, T).zero_() @@ -268,13 +301,12 @@ class Decoder(nn.Module): if greedy: memory_input = outputs[-1] else: - memory_input = memory[t-1] + memory_input = memory[t - 1] # Prenet processed_memory = self.prenet(memory_input) # Attention RNN - attention_cat = torch.cat((attention.unsqueeze(1), - attention_cum.unsqueeze(1)), - dim=1) + attention_cat = torch.cat( + (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1) attention_rnn_hidden, current_context_vec, attention = self.attention_rnn( processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention_cat, input_lens) @@ -293,16 +325,18 @@ class Decoder(nn.Module): output = self.proj_to_mel(decoder_output) stop_input = output # predict stop token - stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden) + stop_token, stopnet_rnn_hidden = self.stopnet( + stop_input, stopnet_rnn_hidden) outputs += [output] attentions += [attention] stop_tokens += [stop_token] t += 1 - if (not greedy and self.training) or (greedy and memory is not None): + if (not greedy and self.training) or (greedy + and memory is not None): if t >= T_decoder: break else: - if t > inputs.shape[1]/2 and stop_token > 0.6: + if t > inputs.shape[1] / 2 and stop_token > 0.6: break elif t > self.max_decoder_steps: print(" | | > Decoder stopped with 'max_decoder_steps") diff --git a/models/tacotron.py b/models/tacotron.py index d07bdd6f..b1b67162 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -6,14 +6,18 @@ from layers.tacotron import Prenet, Encoder, Decoder, CBHG class Tacotron(nn.Module): - def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80, - r=5, padding_idx=None): + def __init__(self, + embedding_dim=256, + linear_dim=1025, + mel_dim=80, + r=5, + padding_idx=None): super(Tacotron, self).__init__() self.r = r self.mel_dim = mel_dim self.linear_dim = linear_dim - self.embedding = nn.Embedding(len(symbols), embedding_dim, - padding_idx=padding_idx) + self.embedding = nn.Embedding( + len(symbols), embedding_dim, padding_idx=padding_idx) print(" | > Number of characters : {}".format(len(symbols))) self.embedding.weight.data.normal_(0, 0.3) self.encoder = Encoder(embedding_dim) diff --git a/notebooks/synthesis.py b/notebooks/synthesis.py index 3b380aa5..ad51a024 100644 --- a/notebooks/synthesis.py +++ b/notebooks/synthesis.py @@ -40,8 +40,12 @@ def visualize(alignment, spectrogram, stop_tokens, CONFIG): plt.plot(range(len(stop_tokens)), 
list(stop_tokens))
     plt.subplot(3, 1, 3)
-    librosa.display.specshow(spectrogram.T, sr=CONFIG.sample_rate,
-                             hop_length=hop_length, x_axis="time", y_axis="linear")
+    librosa.display.specshow(
+        spectrogram.T,
+        sr=CONFIG.sample_rate,
+        hop_length=hop_length,
+        x_axis="time",
+        y_axis="linear")
     plt.xlabel("Time", fontsize=label_fontsize)
     plt.ylabel("Hz", fontsize=label_fontsize)
     plt.tight_layout()
diff --git a/server/README.md b/server/README.md
index 974b695b..97bf811a 100644
--- a/server/README.md
+++ b/server/README.md
@@ -1,9 +1,9 @@
 ## TTS example web-server
 Steps to run:
-1. Download one of the models given on the main page.
-2. Checkout the corresponding commit history.
-2. Set paths and other options in the file ```server/conf.json```.
-3. Run the server ```python server/server.py -c conf.json```. (Requires Flask)
+1. Download one of the models given on the main page. Click [here](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing) for the latest model.
+2. Check out the corresponding commit, or use the ```server``` branch if you want to use the latest model.
+2. Set the paths and other options in the file ```server/conf.json```.
+3. Run the server ```python server/server.py -c server/conf.json```. (Requires Flask)
 4. Go to ```localhost:[given_port]``` and enjoy.
-Note that the audio quality on browser is slightly worse due to the encoder quantization.
\ No newline at end of file
+For high-quality results, please use the library versions shown in the ```requirements.txt``` file.
\ No newline at end of file
diff --git a/server/conf.json b/server/conf.json
index 8c257811..f1589f8d 100644
--- a/server/conf.json
+++ b/server/conf.json
@@ -1,5 +1,5 @@
 {
-    "model_path":"/home/erogol/projects/models/LJSpeech/May-22-2018_03_24PM-e6112f7",
+    "model_path":"../models/May-22-2018_03_24PM-e6112f7",
     "model_name":"checkpoint_272976.pth.tar",
     "model_config":"config.json",
     "port": 5002,
diff --git a/server/server.py b/server/server.py
index 01267447..86ccbc42 100644
--- a/server/server.py
+++ b/server/server.py
@@ -2,12 +2,11 @@ import argparse
 from synthesizer import Synthesizer
 from TTS.utils.generic_utils import load_config
-from flask import (Flask, Response, request,
-                   render_template, send_file)
+from flask import Flask, Response, request, render_template, send_file
 parser = argparse.ArgumentParser()
-parser.add_argument('-c', '--config_path', type=str,
-                    help='path to config file for training')
+parser.add_argument(
+    '-c', '--config_path', type=str, help='path to config file for training')
 args = parser.parse_args()
 config = load_config(args.config_path)
@@ -16,17 +15,19 @@ synthesizer = Synthesizer()
 synthesizer.load_model(config.model_path, config.model_name,
                        config.model_config, config.use_cuda)
+
 @app.route('/')
 def index():
     return render_template('index.html')
+
 @app.route('/api/tts', methods=['GET'])
 def tts():
     text = request.args.get('text')
     print(" > Model input: {}".format(text))
     data = synthesizer.tts(text)
-    return send_file(data,
-                     mimetype='audio/wav')
+    return send_file(data, mimetype='audio/wav')
+
 if __name__ == '__main__':
-    app.run(debug=True, host='0.0.0.0', port=config.port)
\ No newline at end of file
+    app.run(debug=True, host='0.0.0.0', port=config.port)
diff --git a/server/synthesizer.py b/server/synthesizer.py
index 808107fe..534b1313 100644
--- a/server/synthesizer.py
+++ b/server/synthesizer.py
@@ -13,39 +13,44 @@ from matplotlib import pylab as plt
 class Synthesizer(object):
-
     def load_model(self, model_path, model_name,
model_config, use_cuda): model_config = os.path.join(model_path, model_config) - self.model_file = os.path.join(model_path, model_name) + self.model_file = os.path.join(model_path, model_name) print(" > Loading model ...") print(" | > model config: ", model_config) print(" | > model file: ", self.model_file) config = load_config(model_config) self.config = config self.use_cuda = use_cuda - self.model = Tacotron(config.embedding_size, config.num_freq, config.num_mels, config.r) - self.ap = AudioProcessor(config.sample_rate, config.num_mels, config.min_level_db, - config.frame_shift_ms, config.frame_length_ms, config.preemphasis, - config.ref_level_db, config.num_freq, config.power, griffin_lim_iters=60) + self.model = Tacotron(config.embedding_size, config.num_freq, + config.num_mels, config.r) + self.ap = AudioProcessor( + config.sample_rate, + config.num_mels, + config.min_level_db, + config.frame_shift_ms, + config.frame_length_ms, + config.preemphasis, + config.ref_level_db, + config.num_freq, + config.power, + griffin_lim_iters=60) # load model state if use_cuda: cp = torch.load(self.model_file) else: - cp = torch.load(self.model_file, map_location=lambda storage, loc: storage) + cp = torch.load( + self.model_file, map_location=lambda storage, loc: storage) # load the model self.model.load_state_dict(cp['model']) if use_cuda: self.model.cuda() - self.model.eval() - + self.model.eval() + def save_wav(self, wav, path): wav *= 32767 / max(1e-8, np.max(np.abs(wav))) - # sf.write(path, wav.astype(np.int32), self.config.sample_rate, format='wav') - # wav = librosa.util.normalize(wav.astype(np.float), norm=np.inf, axis=None) - # wav = wav / wav.max() - # sf.write(path, wav.astype('float'), self.config.sample_rate, format='ogg') - scipy.io.wavfile.write(path, self.config.sample_rate, wav.astype(np.int16)) - # librosa.output.write_wav(path, wav.astype(np.int16), self.config.sample_rate, norm=True) + librosa.output.write_wav(path, wav.astype(np.int16), + self.config.sample_rate) def tts(self, text): text_cleaner = [self.config.text_cleaner] @@ -54,14 +59,15 @@ class Synthesizer(object): if len(sen) < 3: continue sen = sen.strip() - sen +='.' + sen += '.' 
print(sen) sen = sen.strip() seq = np.array(text_to_sequence(text, text_cleaner)) - chars_var = torch.from_numpy(seq).unsqueeze(0) + chars_var = torch.from_numpy(seq).unsqueeze(0).long() if self.use_cuda: chars_var = chars_var.cuda() - mel_out, linear_out, alignments, stop_tokens = self.model.forward(chars_var) + mel_out, linear_out, alignments, stop_tokens = self.model.forward( + chars_var) linear_out = linear_out[0].data.cpu().numpy() wav = self.ap.inv_spectrogram(linear_out.T) # wav = wav[:self.ap.find_endpoint(wav)] diff --git a/setup.py b/setup.py index 4773d00e..1574094b 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,6 @@ else: class build_py(setuptools.command.build_py.build_py): - def run(self): self.create_version_file() setuptools.command.build_py.build_py.run(self) @@ -40,7 +39,6 @@ class build_py(setuptools.command.build_py.build_py): class develop(setuptools.command.develop.develop): - def run(self): build_py.create_version_file() setuptools.command.develop.develop.run(self) @@ -50,8 +48,11 @@ def create_readme_rst(): global cwd try: subprocess.check_call( - ["pandoc", "--from=markdown", "--to=rst", "--output=README.rst", - "README.md"], cwd=cwd) + [ + "pandoc", "--from=markdown", "--to=rst", "--output=README.rst", + "README.md" + ], + cwd=cwd) print("Generated README.rst from README.md using pandoc.") except subprocess.CalledProcessError: pass @@ -59,33 +60,31 @@ def create_readme_rst(): pass -setup(name='TTS', - version=version, - url='https://github.com/mozilla/TTS', - description='Text to Speech with Deep Learning', - - packages=find_packages(), - cmdclass={ - 'build_py': build_py, - 'develop': develop, - }, - setup_requires=[ - "numpy" - ], - install_requires=[ - "scipy", - "torch == 0.4.0", - "librosa", - "unidecode", - "tensorboardX", - "matplotlib", - "Pillow", - "flask", - "lws", - ], - extras_require={ - "bin": [ - "tqdm", - "requests", - ], - }) \ No newline at end of file +setup( + name='TTS', + version=version, + url='https://github.com/mozilla/TTS', + description='Text to Speech with Deep Learning', + packages=find_packages(), + cmdclass={ + 'build_py': build_py, + 'develop': develop, + }, + setup_requires=["numpy"], + install_requires=[ + "scipy", + "torch == 0.4.0", + "librosa", + "unidecode", + "tensorboardX", + "matplotlib", + "Pillow", + "flask", + "lws", + ], + extras_require={ + "bin": [ + "tqdm", + "requests", + ], + }) diff --git a/tests/generic_utils_text.py b/tests/generic_utils_text.py index b1eb2ddc..63b8bc1a 100644 --- a/tests/generic_utils_text.py +++ b/tests/generic_utils_text.py @@ -8,19 +8,17 @@ OUT_PATH = '/tmp/test.pth.tar' class ModelSavingTests(unittest.TestCase): - def save_checkpoint_test(self): # create a dummy model model = Prenet(128, out_features=[256, 128]) model = T.nn.DataParallel(layer) # save the model - save_checkpoint(model, None, 100, - OUTPATH, 1, 1) + save_checkpoint(model, None, 100, OUTPATH, 1, 1) # load the model to CPU - model_dict = torch.load(MODEL_PATH, map_location=lambda storage, - loc: storage) + model_dict = torch.load( + MODEL_PATH, map_location=lambda storage, loc: storage) model.load_state_dict(model_dict['model']) def save_best_model_test(self): @@ -29,11 +27,9 @@ class ModelSavingTests(unittest.TestCase): model = T.nn.DataParallel(layer) # save the model - best_loss = save_best_model(model, None, 0, - 100, OUT_PATH, - 10, 1) + best_loss = save_best_model(model, None, 0, 100, OUT_PATH, 10, 1) # load the model to CPU - model_dict = torch.load(MODEL_PATH, map_location=lambda storage, - loc: storage) + 
model_dict = torch.load( + MODEL_PATH, map_location=lambda storage, loc: storage) model.load_state_dict(model_dict['model']) diff --git a/tests/layers_tests.py b/tests/layers_tests.py index 9b5c3f73..28cb5cf2 100644 --- a/tests/layers_tests.py +++ b/tests/layers_tests.py @@ -7,7 +7,6 @@ from TTS.utils.generic_utils import sequence_mask class PrenetTests(unittest.TestCase): - def test_in_out(self): layer = Prenet(128, out_features=[256, 128]) dummy_input = T.rand(4, 128) @@ -19,7 +18,6 @@ class PrenetTests(unittest.TestCase): class CBHGTests(unittest.TestCase): - def test_in_out(self): layer = CBHG(128, K=6, projections=[128, 128], num_highways=2) dummy_input = T.rand(4, 8, 128) @@ -32,7 +30,6 @@ class CBHGTests(unittest.TestCase): class DecoderTests(unittest.TestCase): - def test_in_out(self): layer = Decoder(in_features=256, memory_dim=80, r=2) dummy_input = T.rand(4, 8, 256) @@ -49,7 +46,6 @@ class DecoderTests(unittest.TestCase): class EncoderTests(unittest.TestCase): - def test_in_out(self): layer = Encoder(128) dummy_input = T.rand(4, 8, 128) @@ -63,7 +59,6 @@ class EncoderTests(unittest.TestCase): class L1LossMaskedTests(unittest.TestCase): - def test_in_out(self): layer = L1LossMasked() dummy_input = T.ones(4, 8, 128).float() @@ -80,7 +75,7 @@ class L1LossMaskedTests(unittest.TestCase): dummy_input = T.ones(4, 8, 128).float() dummy_target = T.zeros(4, 8, 128).float() dummy_length = (T.arange(5, 9)).long() - mask = ((sequence_mask(dummy_length).float() - 1.0) - * 100.0).unsqueeze(2) + mask = ( + (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) assert output.item() == 1.0, "1.0 vs {}".format(output.data[0]) diff --git a/tests/loader_tests.py b/tests/loader_tests.py index 927126b4..5d5cbe52 100644 --- a/tests/loader_tests.py +++ b/tests/loader_tests.py @@ -12,34 +12,38 @@ c = load_config(os.path.join(file_path, 'test_config.json')) class TestLJSpeechDataset(unittest.TestCase): - def __init__(self, *args, **kwargs): super(TestLJSpeechDataset, self).__init__(*args, **kwargs) self.max_loader_iter = 4 - self.ap = AudioProcessor(sample_rate=c.sample_rate, - num_mels=c.num_mels, - min_level_db=c.min_level_db, - frame_shift_ms=c.frame_shift_ms, - frame_length_ms=c.frame_length_ms, - ref_level_db=c.ref_level_db, - num_freq=c.num_freq, - power=c.power, - preemphasis=c.preemphasis, - min_mel_freq=c.min_mel_freq, - max_mel_freq=c.max_mel_freq) + self.ap = AudioProcessor( + sample_rate=c.sample_rate, + num_mels=c.num_mels, + min_level_db=c.min_level_db, + frame_shift_ms=c.frame_shift_ms, + frame_length_ms=c.frame_length_ms, + ref_level_db=c.ref_level_db, + num_freq=c.num_freq, + power=c.power, + preemphasis=c.preemphasis, + min_mel_freq=c.min_mel_freq, + max_mel_freq=c.max_mel_freq) def test_loader(self): - dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech), - os.path.join(c.data_path_LJSpeech, 'metadata.csv'), - c.r, - c.text_cleaner, - ap = self.ap, - min_seq_len=c.min_seq_len - ) + dataset = LJSpeech.MyDataset( + os.path.join(c.data_path_LJSpeech), + os.path.join(c.data_path_LJSpeech, 'metadata.csv'), + c.r, + c.text_cleaner, + ap=self.ap, + min_seq_len=c.min_seq_len) - dataloader = DataLoader(dataset, batch_size=2, - shuffle=True, collate_fn=dataset.collate_fn, - drop_last=True, num_workers=c.num_loader_workers) + dataloader = DataLoader( + dataset, + batch_size=2, + shuffle=True, + collate_fn=dataset.collate_fn, + drop_last=True, + num_workers=c.num_loader_workers) for i, data in 
enumerate(dataloader): if i == self.max_loader_iter: @@ -62,18 +66,22 @@ class TestLJSpeechDataset(unittest.TestCase): assert mel_input.shape[2] == c.num_mels def test_padding(self): - dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech), - os.path.join(c.data_path_LJSpeech, 'metadata.csv'), - 1, - c.text_cleaner, - ap = self.ap, - min_seq_len=c.min_seq_len - ) + dataset = LJSpeech.MyDataset( + os.path.join(c.data_path_LJSpeech), + os.path.join(c.data_path_LJSpeech, 'metadata.csv'), + 1, + c.text_cleaner, + ap=self.ap, + min_seq_len=c.min_seq_len) # Test for batch size 1 - dataloader = DataLoader(dataset, batch_size=1, - shuffle=False, collate_fn=dataset.collate_fn, - drop_last=True, num_workers=c.num_loader_workers) + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=True, + num_workers=c.num_loader_workers) for i, data in enumerate(dataloader): if i == self.max_loader_iter: @@ -98,9 +106,13 @@ class TestLJSpeechDataset(unittest.TestCase): assert mel_lengths[0] == mel_input[0].shape[0] # Test for batch size 2 - dataloader = DataLoader(dataset, batch_size=2, - shuffle=False, collate_fn=dataset.collate_fn, - drop_last=False, num_workers=c.num_loader_workers) + dataloader = DataLoader( + dataset, + batch_size=2, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=False, + num_workers=c.num_loader_workers) for i, data in enumerate(dataloader): if i == self.max_loader_iter: @@ -130,9 +142,9 @@ class TestLJSpeechDataset(unittest.TestCase): assert mel_lengths[idx] == mel_input[idx].shape[0] # check the second itme in the batch - assert mel_input[1-idx, -1].sum() == 0 - assert linear_input[1-idx, -1].sum() == 0 - assert stop_target[1-idx, -1] == 1 + assert mel_input[1 - idx, -1].sum() == 0 + assert linear_input[1 - idx, -1].sum() == 0 + assert stop_target[1 - idx, -1] == 1 assert len(mel_lengths.shape) == 1 # check batch conditions @@ -141,34 +153,38 @@ class TestLJSpeechDataset(unittest.TestCase): class TestKusalDataset(unittest.TestCase): - def __init__(self, *args, **kwargs): super(TestKusalDataset, self).__init__(*args, **kwargs) self.max_loader_iter = 4 - self.ap = AudioProcessor(sample_rate=c.sample_rate, - num_mels=c.num_mels, - min_level_db=c.min_level_db, - frame_shift_ms=c.frame_shift_ms, - frame_length_ms=c.frame_length_ms, - ref_level_db=c.ref_level_db, - num_freq=c.num_freq, - power=c.power, - preemphasis=c.preemphasis, - min_mel_freq=c.min_mel_freq, - max_mel_freq=c.max_mel_freq) + self.ap = AudioProcessor( + sample_rate=c.sample_rate, + num_mels=c.num_mels, + min_level_db=c.min_level_db, + frame_shift_ms=c.frame_shift_ms, + frame_length_ms=c.frame_length_ms, + ref_level_db=c.ref_level_db, + num_freq=c.num_freq, + power=c.power, + preemphasis=c.preemphasis, + min_mel_freq=c.min_mel_freq, + max_mel_freq=c.max_mel_freq) def test_loader(self): - dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal), - os.path.join(c.data_path_Kusal, 'prompts.txt'), - c.r, - c.text_cleaner, - ap = self.ap, - min_seq_len=c.min_seq_len - ) + dataset = Kusal.MyDataset( + os.path.join(c.data_path_Kusal), + os.path.join(c.data_path_Kusal, 'prompts.txt'), + c.r, + c.text_cleaner, + ap=self.ap, + min_seq_len=c.min_seq_len) - dataloader = DataLoader(dataset, batch_size=2, - shuffle=True, collate_fn=dataset.collate_fn, - drop_last=True, num_workers=c.num_loader_workers) + dataloader = DataLoader( + dataset, + batch_size=2, + shuffle=True, + collate_fn=dataset.collate_fn, + drop_last=True, + num_workers=c.num_loader_workers) 
for i, data in enumerate(dataloader): if i == self.max_loader_iter: @@ -191,18 +207,22 @@ class TestKusalDataset(unittest.TestCase): assert mel_input.shape[2] == c.num_mels def test_padding(self): - dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal), - os.path.join(c.data_path_Kusal, 'prompts.txt'), - 1, - c.text_cleaner, - ap = self.ap, - min_seq_len=c.min_seq_len - ) + dataset = Kusal.MyDataset( + os.path.join(c.data_path_Kusal), + os.path.join(c.data_path_Kusal, 'prompts.txt'), + 1, + c.text_cleaner, + ap=self.ap, + min_seq_len=c.min_seq_len) # Test for batch size 1 - dataloader = DataLoader(dataset, batch_size=1, - shuffle=False, collate_fn=dataset.collate_fn, - drop_last=True, num_workers=c.num_loader_workers) + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=True, + num_workers=c.num_loader_workers) for i, data in enumerate(dataloader): if i == self.max_loader_iter: @@ -227,9 +247,13 @@ class TestKusalDataset(unittest.TestCase): assert mel_lengths[0] == mel_input[0].shape[0] # Test for batch size 2 - dataloader = DataLoader(dataset, batch_size=2, - shuffle=False, collate_fn=dataset.collate_fn, - drop_last=False, num_workers=c.num_loader_workers) + dataloader = DataLoader( + dataset, + batch_size=2, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=False, + num_workers=c.num_loader_workers) for i, data in enumerate(dataloader): if i == self.max_loader_iter: @@ -259,16 +283,16 @@ class TestKusalDataset(unittest.TestCase): assert mel_lengths[idx] == mel_input[idx].shape[0] # check the second itme in the batch - assert mel_input[1-idx, -1].sum() == 0 - assert linear_input[1-idx, -1].sum() == 0 - assert stop_target[1-idx, -1] == 1 + assert mel_input[1 - idx, -1].sum() == 0 + assert linear_input[1 - idx, -1].sum() == 0 + assert stop_target[1 - idx, -1] == 1 assert len(mel_lengths.shape) == 1 # check batch conditions assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 - + # class TestTWEBDataset(unittest.TestCase): # def __init__(self, *args, **kwargs): @@ -339,7 +363,7 @@ class TestKusalDataset(unittest.TestCase): # for i, data in enumerate(dataloader): # if i == self.max_loader_iter: # break - + # text_input = data[0] # text_lengths = data[1] # linear_input = data[2] @@ -399,4 +423,4 @@ class TestKusalDataset(unittest.TestCase): # # check batch conditions # assert (mel_input * stop_target.unsqueeze(2)).sum() == 0 -# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 \ No newline at end of file +# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0 diff --git a/tests/tacotron_tests.py b/tests/tacotron_tests.py index e1b9c91d..52a5dfcd 100644 --- a/tests/tacotron_tests.py +++ b/tests/tacotron_tests.py @@ -19,50 +19,51 @@ c = load_config(os.path.join(file_path, 'test_config.json')) class TacotronTrainTest(unittest.TestCase): - def test_train_step(self): input = torch.randint(0, 24, (8, 128)).long().to(device) mel_spec = torch.rand(8, 30, c.num_mels).to(device) linear_spec = torch.rand(8, 30, c.num_freq).to(device) - mel_lengths = torch.randint(20, 30, (8,)).long().to(device) + mel_lengths = torch.randint(20, 30, (8, )).long().to(device) stop_targets = torch.zeros(8, 30, 1).float().to(device) - + for idx in mel_lengths: stop_targets[:, int(idx.item()):, 0] = 1.0 - - stop_targets = stop_targets.view(input.shape[0], stop_targets.size(1) // c.r, -1) + + stop_targets = stop_targets.view(input.shape[0], + stop_targets.size(1) // c.r, 
-1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float() - + criterion = L1LossMasked().to(device) criterion_st = nn.BCELoss().to(device) - model = Tacotron(c.embedding_size, - c.num_freq, - c.num_mels, + model = Tacotron(c.embedding_size, c.num_freq, c.num_mels, c.r).to(device) model.train() model_ref = copy.deepcopy(model) count = 0 - for param, param_ref in zip(model.parameters(), model_ref.parameters()): + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=c.lr) for i in range(5): - mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec) + mel_out, linear_out, align, stop_tokens = model.forward( + input, mel_spec) assert stop_tokens.data.max() <= 1.0 assert stop_tokens.data.min() >= 0.0 optimizer.zero_grad() - loss = criterion(mel_out, mel_spec, mel_lengths) + loss = criterion(mel_out, mel_spec, mel_lengths) stop_loss = criterion_st(stop_tokens, stop_targets) - loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss + loss = loss + criterion(linear_out, linear_spec, + mel_lengths) + stop_loss loss.backward() optimizer.step() # check parameter changes count = 0 - for param, param_ref in zip(model.parameters(), model_ref.parameters()): - # ignore pre-higway layer since it works conditional + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + # ignore pre-higway layer since it works conditional if count not in [148, 59]: - assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref) + assert (param != param_ref).any( + ), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref) count += 1 - - - \ No newline at end of file diff --git a/train.py b/train.py index 5b73e484..367ffff2 100644 --- a/train.py +++ b/train.py @@ -1,39 +1,34 @@ import os import sys import time -import datetime import shutil import torch -import signal import argparse import importlib -import pickle import traceback import numpy as np import torch.nn as nn from torch import optim -from torch import onnx from torch.utils.data import DataLoader -from torch.optim.lr_scheduler import ReduceLROnPlateau from tensorboardX import SummaryWriter -from utils.generic_utils import (synthesis, remove_experiment_folder, - create_experiment_folder, save_checkpoint, - save_best_model, load_config, lr_decay, - count_parameters, check_update, get_commit_hash) +from utils.generic_utils import ( + synthesis, remove_experiment_folder, create_experiment_folder, + save_checkpoint, save_best_model, load_config, lr_decay, count_parameters, + check_update, get_commit_hash) from utils.visual import plot_alignment, plot_spectrogram from models.tacotron import Tacotron from layers.losses import L1LossMasked from utils.audio import AudioProcessor - torch.manual_seed(1) torch.set_num_threads(4) use_cuda = torch.cuda.is_available() -def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, ap, epoch): +def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, + ap, epoch): model = model.train() epoch_time = 0 avg_linear_loss = 0 @@ -54,7 +49,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, stop_targets = data[5] # set stop targets view, we predict a single stop token per r frames prediction - stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, 
-1) + stop_targets = stop_targets.view(text_input.shape[0], + stop_targets.size(1) // c.r, -1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float() current_step = num_iter + args.restore_step + \ @@ -89,7 +85,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, # loss computation stop_loss = criterion_st(stop_tokens, stop_targets) mel_loss = criterion(mel_output, mel_input, mel_lengths) - linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \ + linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths)\ + 0.5 * criterion(linear_output[:, :, :n_priority_freq], linear_input[:, :, :n_priority_freq], mel_lengths) @@ -106,7 +102,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, # backpass and check the grad norm for stop loss stop_loss.backward() - grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100) + grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, + 0.5, 100) if skip_flag: optimizer_st.zero_grad() print(" | | > Iteration skipped fro stopnet!!") @@ -117,9 +114,10 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch_time += step_time if current_step % c.print_step == 0: - print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "\ - "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "\ - "GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter, current_step, + print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} " + "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} " + "GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter, + current_step, loss.item(), linear_loss.item(), mel_loss.item(), @@ -147,8 +145,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, if current_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, optimizer, optimizer_st, linear_loss.item(), - OUT_PATH, current_step, epoch) + save_checkpoint(model, optimizer, optimizer_st, + linear_loss.item(), OUT_PATH, current_step, + epoch) # Diagnostic visualizations const_spec = linear_output[0].data.cpu().numpy() @@ -168,8 +167,11 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, ap.griffin_lim_iters = 60 audio_signal = ap.inv_spectrogram(audio_signal.T) try: - tb.add_audio('SampleAudio', audio_signal, current_step, - sample_rate=c.sample_rate) + tb.add_audio( + 'SampleAudio', + audio_signal, + current_step, + sample_rate=c.sample_rate) except: pass @@ -180,9 +182,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, avg_step_time /= (num_iter + 1) # print epoch stats - print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "\ - "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "\ - "AvgStopLoss:{:.5f} EpochTime:{:.2f} "\ + print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} " + "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} " + "AvgStopLoss:{:.5f} EpochTime:{:.2f} " "AvgStepTime:{:.2f}".format(current_step, avg_total_loss, avg_linear_loss, @@ -209,10 +211,12 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step): avg_mel_loss = 0 avg_stop_loss = 0 print(" | > Validation") - test_sentences = ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", - "Be a voice, not an echo.", - "I'm sorry Dave. I'm afraid I can't do that.", - "This cake is great. 
It's so delicious and moist."] + test_sentences = [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist." + ] n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) with torch.no_grad(): if data_loader is not None: @@ -228,7 +232,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step): stop_targets = data[5] # set stop targets view, we predict a single stop token per r frames prediction - stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) + stop_targets = stop_targets.view(text_input.shape[0], + stop_targets.size(1) // c.r, + -1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float() # dispatch data to GPU @@ -256,11 +262,11 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step): epoch_time += step_time if num_iter % c.print_step == 0: - print(" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "\ - "StopLoss: {:.5f} ".format(loss.item(), - linear_loss.item(), - mel_loss.item(), - stop_loss.item()), flush=True) + print( + " | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} " + "StopLoss: {:.5f} ".format(loss.item(), linear_loss.item(), + mel_loss.item(), stop_loss.item()), + flush=True) avg_linear_loss += linear_loss.item() avg_mel_loss += mel_loss.item() @@ -278,15 +284,19 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step): tb.add_image('ValVisual/Reconstruction', const_spec, current_step) tb.add_image('ValVisual/GroundTruth', gt_spec, current_step) - tb.add_image('ValVisual/ValidationAlignment', align_img, current_step) + tb.add_image('ValVisual/ValidationAlignment', align_img, + current_step) # Sample audio audio_signal = linear_output[idx].data.cpu().numpy() ap.griffin_lim_iters = 60 audio_signal = ap.inv_spectrogram(audio_signal.T) try: - tb.add_audio('ValSampleAudio', audio_signal, current_step, - sample_rate=c.sample_rate) + tb.add_audio( + 'ValSampleAudio', + audio_signal, + current_step, + sample_rate=c.sample_rate) except: # sometimes audio signal is out of boundaries pass @@ -298,81 +308,88 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step): avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss # Plot Learning Stats - tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step) - tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step) + tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, + current_step) + tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, + current_step) tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step) - tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step) + tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, + current_step) # test sentences ap.griffin_lim_iters = 60 for idx, test_sentence in enumerate(test_sentences): - wav, linear_spec, alignments = synthesis(model, ap, test_sentence, use_cuda, - c.text_cleaner) + wav, linear_spec, alignments = synthesis(model, ap, test_sentence, + use_cuda, c.text_cleaner) try: wav_name = 'TestSentences/{}'.format(idx) - tb.add_audio(wav_name, wav, current_step, - sample_rate=c.sample_rate) + tb.add_audio( + wav_name, wav, current_step, sample_rate=c.sample_rate) except: pass align_img = alignments[0].data.cpu().numpy() linear_spec = plot_spectrogram(linear_spec, ap) align_img = 
plot_alignment(align_img) - tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec, current_step) - tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img, current_step) + tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec, + current_step) + tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img, + current_step) return avg_linear_loss def main(args): - dataset = importlib.import_module('datasets.'+c.dataset) + dataset = importlib.import_module('datasets.' + c.dataset) Dataset = getattr(dataset, 'MyDataset') - audio = importlib.import_module('utils.'+c.audio_processor) + audio = importlib.import_module('utils.' + c.audio_processor) AudioProcessor = getattr(audio, 'AudioProcessor') - ap = AudioProcessor(sample_rate=c.sample_rate, - num_mels=c.num_mels, - min_level_db=c.min_level_db, - frame_shift_ms=c.frame_shift_ms, - frame_length_ms=c.frame_length_ms, - ref_level_db=c.ref_level_db, - num_freq=c.num_freq, - power=c.power, - preemphasis=c.preemphasis, - min_mel_freq=c.min_mel_freq, - max_mel_freq=c.max_mel_freq) + ap = AudioProcessor( + sample_rate=c.sample_rate, + num_mels=c.num_mels, + min_level_db=c.min_level_db, + frame_shift_ms=c.frame_shift_ms, + frame_length_ms=c.frame_length_ms, + ref_level_db=c.ref_level_db, + num_freq=c.num_freq, + power=c.power, + preemphasis=c.preemphasis, + min_mel_freq=c.min_mel_freq, + max_mel_freq=c.max_mel_freq) # Setup the dataset - train_dataset = Dataset(c.data_path, - c.meta_file_train, - c.r, - c.text_cleaner, - ap = ap, - min_seq_len=c.min_seq_len - ) + train_dataset = Dataset( + c.data_path, + c.meta_file_train, + c.r, + c.text_cleaner, + ap=ap, + min_seq_len=c.min_seq_len) - train_loader = DataLoader(train_dataset, batch_size=c.batch_size, - shuffle=False, collate_fn=train_dataset.collate_fn, - drop_last=False, num_workers=c.num_loader_workers, - pin_memory=True) + train_loader = DataLoader( + train_dataset, + batch_size=c.batch_size, + shuffle=False, + collate_fn=train_dataset.collate_fn, + drop_last=False, + num_workers=c.num_loader_workers, + pin_memory=True) if c.run_eval: - val_dataset = Dataset(c.data_path, - c.meta_file_val, - c.r, - c.text_cleaner, - ap = ap - ) + val_dataset = Dataset( + c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap) - val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size, - shuffle=False, collate_fn=val_dataset.collate_fn, - drop_last=False, num_workers=4, - pin_memory=True) + val_loader = DataLoader( + val_dataset, + batch_size=c.eval_batch_size, + shuffle=False, + collate_fn=val_dataset.collate_fn, + drop_last=False, + num_workers=4, + pin_memory=True) else: val_loader = None - model = Tacotron(c.embedding_size, - ap.num_freq, - c.num_mels, - c.r) + model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r) print(" | > Num output units : {}".format(ap.num_freq), flush=True) optimizer = optim.Adam(model.parameters(), lr=c.lr) @@ -394,7 +411,8 @@ def main(args): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.cuda() - print(" > Model restored from step %d" % checkpoint['step'], flush=True) + print( + " > Model restored from step %d" % checkpoint['step'], flush=True) start_epoch = checkpoint['step'] // len(train_loader) best_loss = checkpoint['linear_loss'] args.restore_step = checkpoint['step'] @@ -416,22 +434,36 @@ def main(args): best_loss = float('inf') for epoch in range(0, c.epochs): - train_loss, current_step = train(model, criterion, criterion_st, train_loader, optimizer, optimizer_st, ap, epoch) - val_loss = 
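Earlier in main() above, the dataset and audio-processor classes are resolved by name with importlib, so the config strings (c.dataset, c.audio_processor) pick the implementation at run time without code changes. A small sketch of the same lookup pattern; the package name 'datasets' and class name 'MyDataset' mirror the code above, while 'LJSpeech' in the usage comment is only an example value:

    import importlib

    def load_dataset_class(dataset_name):
        # e.g. 'LJSpeech' resolves to datasets/LJSpeech.py and its MyDataset class
        module = importlib.import_module('datasets.' + dataset_name)
        return getattr(module, 'MyDataset')

    # Usage (assuming the repository's datasets package is importable):
    # MyDataset = load_dataset_class(c.dataset)
    # train_dataset = MyDataset(c.data_path, c.meta_file_train, c.r, c.text_cleaner, ap=ap)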
evaluate(model, criterion, criterion_st, val_loader, ap, current_step)
-        print(" | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(train_loss, val_loss), flush=True)
-        best_loss = save_best_model(model, optimizer, train_loss,
-                                    best_loss, OUT_PATH,
-                                    current_step, epoch)
+        train_loss, current_step = train(model, criterion, criterion_st,
+                                         train_loader, optimizer, optimizer_st,
+                                         ap, epoch)
+        val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
+                            current_step)
+        print(
+            " | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(
+                train_loss, val_loss),
+            flush=True)
+        best_loss = save_best_model(model, optimizer, train_loss, best_loss,
+                                    OUT_PATH, current_step, epoch)
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--restore_path', type=str,
-                        help='Folder path to checkpoints', default=0)
-    parser.add_argument('--config_path', type=str,
-                        help='path to config file for training',)
-    parser.add_argument('--debug', type=bool, default=False,
-                        help='do not ask for git has before run.')
+    parser.add_argument(
+        '--restore_path',
+        type=str,
+        help='Folder path to checkpoints',
+        default=0)
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        help='path to config file for training',
+    )
+    parser.add_argument(
+        '--debug',
+        type=bool,
+        default=False,
+        help='do not ask for git has before run.')
     args = parser.parse_args()
 
     # setup output paths and read configs
diff --git a/utils/audio.py b/utils/audio.py
index 42e3f0a2..1e496211 100644
--- a/utils/audio.py
+++ b/utils/audio.py
@@ -9,10 +9,19 @@ _mel_basis = None
 
 
 class AudioProcessor(object):
-
-    def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
-                 frame_length_ms, ref_level_db, num_freq, power, preemphasis,
-                 min_mel_freq, max_mel_freq, griffin_lim_iters=None):
+    def __init__(self,
+                 sample_rate,
+                 num_mels,
+                 min_level_db,
+                 frame_shift_ms,
+                 frame_length_ms,
+                 ref_level_db,
+                 num_freq,
+                 power,
+                 preemphasis,
+                 min_mel_freq,
+                 max_mel_freq,
+                 griffin_lim_iters=None):
 
         self.sample_rate = sample_rate
         self.num_mels = num_mels
@@ -30,7 +39,8 @@ class AudioProcessor(object):
 
     def save_wav(self, wav, path):
         wav *= 32767 / max(0.01, np.max(np.abs(wav)))
-        librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
+        librosa.output.write_wav(
+            path, wav.astype(np.float), self.sample_rate, norm=True)
 
     def _linear_to_mel(self, spectrogram):
         global _mel_basis
@@ -40,8 +50,9 @@ class AudioProcessor(object):
 
     def _build_mel_basis(self, ):
         n_fft = (self.num_freq - 1) * 2
-        return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
-            # fmin=self.min_mel_freq, fmax=self.max_mel_freq)
+        return librosa.filters.mel(
+            self.sample_rate, n_fft, n_mels=self.num_mels)
+        # fmin=self.min_mel_freq, fmax=self.max_mel_freq)
 
     def _normalize(self, S):
         return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
@@ -66,7 +77,7 @@ class AudioProcessor(object):
         if self.preemphasis == 0:
             raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
         return signal.lfilter([1, -self.preemphasis], [1], x)
-
+
    def apply_inv_preemphasis(self, x):
        if self.preemphasis == 0:
            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. 
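apply_preemphasis() and apply_inv_preemphasis() above wrap a first-order pre-emphasis filter and its exact inverse around scipy.signal.lfilter. A standalone sketch of that pair; the 0.97 coefficient and the random test signal are illustrative (the project takes the factor from its config):

    import numpy as np
    from scipy import signal

    preemphasis = 0.97
    x = np.random.randn(22050).astype(np.float32)            # 1 s of fake audio

    # y[n] = x[n] - 0.97 * x[n-1]  (FIR high-pass), then the matching IIR inverse
    emphasized = signal.lfilter([1, -preemphasis], [1], x)
    restored = signal.lfilter([1], [1, -preemphasis], emphasized)

    print(np.allclose(x, restored, atol=1e-4))               # True: the two filters are inverses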
") @@ -86,9 +97,9 @@ class AudioProcessor(object): S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear # Reconstruct phase if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) else: - return self._griffin_lim(S ** self.power) + return self._griffin_lim(S**self.power) def _griffin_lim(self, S): '''Applies Griffin-Lim's raw. @@ -113,7 +124,8 @@ class AudioProcessor(object): def _stft(self, y): n_fft, hop_length, win_length = self._stft_parameters() - return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) + return librosa.stft( + y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) def _istft(self, y): _, hop_length, win_length = self._stft_parameters() diff --git a/utils/audio_lws.py b/utils/audio_lws.py index 30c819b3..d3bda848 100644 --- a/utils/audio_lws.py +++ b/utils/audio_lws.py @@ -8,11 +8,23 @@ import lws _mel_basis = None -class AudioProcessor(object): - def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms, - frame_length_ms, ref_level_db, num_freq, power, preemphasis, - min_mel_freq, max_mel_freq, griffin_lim_iters=None, ): +class AudioProcessor(object): + def __init__( + self, + sample_rate, + num_mels, + min_level_db, + frame_shift_ms, + frame_length_ms, + ref_level_db, + num_freq, + power, + preemphasis, + min_mel_freq, + max_mel_freq, + griffin_lim_iters=None, + ): print(" > Setting up Audio Processor...") self.sample_rate = sample_rate self.num_mels = num_mels @@ -25,18 +37,19 @@ class AudioProcessor(object): self.min_mel_freq = min_mel_freq self.max_mel_freq = max_mel_freq self.griffin_lim_iters = griffin_lim_iters - self.preemphasis =preemphasis + self.preemphasis = preemphasis self.n_fft, self.hop_length, self.win_length = self._stft_parameters() if preemphasis == 0: print(" | > Preemphasis is deactive.") def save_wav(self, wav, path): wav *= 32767 / max(0.01, np.max(np.abs(wav))) - librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True) - + librosa.output.write_wav( + path, wav.astype(np.float), self.sample_rate, norm=True) + def _stft_parameters(self, ): n_fft = int((self.num_freq - 1) * 2) - hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) + hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate) if n_fft % hop_length != 0: hop_length = n_fft / 8 @@ -44,14 +57,21 @@ class AudioProcessor(object): if n_fft % win_length != 0: win_length = n_fft / 2 print(" | > win_length is set to default ({}).".format(win_length)) - print(" | > fft size: {}, hop length: {}, win length: {}".format(n_fft, hop_length, win_length)) + print(" | > fft size: {}, hop length: {}, win length: {}".format( + n_fft, hop_length, win_length)) return int(n_fft), int(hop_length), int(win_length) - + def _lws_processor(self): try: - return lws.lws(self.win_length, self.hop_length, fftsize=self.n_fft, mode="speech") + return lws.lws( + self.win_length, + self.hop_length, + fftsize=self.n_fft, + mode="speech") except: - raise RuntimeError(" !! WindowLength({}) is not multiple of HopLength({}).".format(self.win_length, self.hop_length)) + raise RuntimeError( + " !! WindowLength({}) is not multiple of HopLength({}).". 
+ format(self.win_length, self.hop_length)) def _amp_to_db(self, x): min_level = np.exp(self.min_level_db / 20 * np.log(10)) @@ -70,7 +90,7 @@ class AudioProcessor(object): if self.preemphasis == 0: raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ") return signal.lfilter([1, -self.preemphasis], [1], x) - + def apply_inv_preemphasis(self, x): if self.preemphasis == 0: raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ") @@ -96,14 +116,14 @@ class AudioProcessor(object): S = self._denormalize(spectrogram) S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear processor = self._lws_processor() - D = processor.run_lws(S.astype(np.float64).T ** self.power) + D = processor.run_lws(S.astype(np.float64).T**self.power) y = processor.istft(D).astype(np.float32) # Reconstruct phase if self.preemphasis: return self.apply_inv_preemphasis(y) sys.stdout = old_out return y - + def _linear_to_mel(self, spectrogram): global _mel_basis if _mel_basis is None: @@ -111,7 +131,10 @@ class AudioProcessor(object): return np.dot(_mel_basis, spectrogram) def _build_mel_basis(self, ): - return librosa.filters.mel(self.sample_rate, self.n_fft, n_mels=self.num_mels) + return librosa.filters.mel( + self.sample_rate, self.n_fft, n_mels=self.num_mels) + + # fmin=self.min_mel_freq, fmax=self.max_mel_freq) def melspectrogram(self, y): @@ -124,4 +147,4 @@ class AudioProcessor(object): D = self._lws_processor().stft(y).T S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db sys.stdout = old_out - return self._normalize(S) \ No newline at end of file + return self._normalize(S) diff --git a/utils/data.py b/utils/data.py index 51d8acb1..f7f1d0ee 100644 --- a/utils/data.py +++ b/utils/data.py @@ -4,9 +4,8 @@ import numpy as np def _pad_data(x, length): _pad = 0 assert x.ndim == 1 - return np.pad(x, (0, length - x.shape[0]), - mode='constant', - constant_values=_pad) + return np.pad( + x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) def prepare_data(inputs): @@ -17,8 +16,10 @@ def prepare_data(inputs): def _pad_tensor(x, length): _pad = 0 assert x.ndim == 2 - x = np.pad(x, [[0, 0], [0, length - x.shape[1]]], - mode='constant', constant_values=_pad) + x = np.pad( + x, [[0, 0], [0, length - x.shape[1]]], + mode='constant', + constant_values=_pad) return x @@ -32,7 +33,8 @@ def prepare_tensor(inputs, out_steps): def _pad_stop_target(x, length): _pad = 1. assert x.ndim == 1 - return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) + return np.pad( + x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) def prepare_stop_target(inputs, out_steps): @@ -44,6 +46,7 @@ def prepare_stop_target(inputs, out_steps): def pad_per_step(inputs, pad_len): timesteps = inputs.shape[-1] - return np.pad(inputs, [[0, 0], [0, 0], - [0, pad_len]], - mode='constant', constant_values=0.0) + return np.pad( + inputs, [[0, 0], [0, 0], [0, pad_len]], + mode='constant', + constant_values=0.0) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index d62b648a..c9a84fb6 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -28,10 +28,13 @@ def load_config(config_path): def get_commit_hash(): """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" try: - subprocess.check_output(['git', 'diff-index', '--quiet', 'HEAD']) # Verify client is clean + subprocess.check_output(['git', 'diff-index', '--quiet', + 'HEAD']) # Verify client is clean except: - raise RuntimeError(" !! 
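The padding helpers in utils/data.py above bring every spectrogram in a batch to a common time length that is a multiple of the reduction factor, so the decoder can emit r frames per step over a rectangular tensor. A sketch of that idea; the helper name pad_spectrograms and the 80x{11,7} toy arrays are made up for illustration and the real prepare_tensor() may differ in details such as the extra zero frame:

    import numpy as np

    def pad_spectrograms(specs, r):
        # Pad each (num_freq, T_i) array on the time axis up to the longest T,
        # rounded up to the next multiple of r, then stack into one batch tensor.
        max_len = max(s.shape[1] for s in specs)
        remainder = max_len % r
        pad_len = max_len + (r - remainder) if remainder > 0 else max_len
        return np.stack([
            np.pad(s, [[0, 0], [0, pad_len - s.shape[1]]],
                   mode='constant', constant_values=0.0)
            for s in specs
        ])

    batch = [np.ones((80, 11)), np.ones((80, 7))]
    print(pad_spectrograms(batch, r=5).shape)   # (2, 80, 15)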
Commit before training to get the commit hash.") - commit = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode().strip() + raise RuntimeError( + " !! Commit before training to get the commit hash.") + commit = subprocess.check_output(['git', 'rev-parse', '--short', + 'HEAD']).decode().strip() print(' > Git Hash: {}'.format(commit)) return commit @@ -43,7 +46,8 @@ def create_experiment_folder(root_path, model_name, debug): commit_hash = 'debug' else: commit_hash = get_commit_hash() - output_folder = os.path.join(root_path, date_str + '-' + model_name + '-' + commit_hash) + output_folder = os.path.join( + root_path, date_str + '-' + model_name + '-' + commit_hash) os.makedirs(output_folder, exist_ok=True) print(" > Experiment folder: {}".format(output_folder)) return output_folder @@ -52,7 +56,7 @@ def create_experiment_folder(root_path, model_name, debug): def remove_experiment_folder(experiment_path): """Check folder if there is a checkpoint, otherwise remove the folder""" - checkpoint_files = glob.glob(experiment_path+"/*.pth.tar") + checkpoint_files = glob.glob(experiment_path + "/*.pth.tar") if len(checkpoint_files) < 1: if os.path.exists(experiment_path): shutil.rmtree(experiment_path) @@ -86,13 +90,15 @@ def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path, print(" | | > Checkpoint saving : {}".format(checkpoint_path)) new_state_dict = _trim_model_state_dict(model.state_dict()) - state = {'model': new_state_dict, - 'optimizer': optimizer.state_dict(), - 'optimizer_st': optimizer_st.state_dict(), - 'step': current_step, - 'epoch': epoch, - 'linear_loss': model_loss, - 'date': datetime.date.today().strftime("%B %d, %Y")} + state = { + 'model': new_state_dict, + 'optimizer': optimizer.state_dict(), + 'optimizer_st': optimizer_st.state_dict(), + 'step': current_step, + 'epoch': epoch, + 'linear_loss': model_loss, + 'date': datetime.date.today().strftime("%B %d, %Y") + } torch.save(state, checkpoint_path) @@ -100,12 +106,14 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step, epoch): if model_loss < best_loss: new_state_dict = _trim_model_state_dict(model.state_dict()) - state = {'model': new_state_dict, - 'optimizer': optimizer.state_dict(), - 'step': current_step, - 'epoch': epoch, - 'linear_loss': model_loss, - 'date': datetime.date.today().strftime("%B %d, %Y")} + state = { + 'model': new_state_dict, + 'optimizer': optimizer.state_dict(), + 'step': current_step, + 'epoch': epoch, + 'linear_loss': model_loss, + 'date': datetime.date.today().strftime("%B %d, %Y") + } best_loss = model_loss bestmodel_path = 'best_model.pth.tar' bestmodel_path = os.path.join(out_path, bestmodel_path) @@ -161,12 +169,12 @@ def sequence_mask(sequence_length, max_len=None): def synthesis(model, ap, text, use_cuda, text_cleaner): - text_cleaner = [text_cleaner] - seq = np.array(text_to_sequence(text, text_cleaner)) - chars_var = torch.from_numpy(seq).unsqueeze(0) - if use_cuda: - chars_var = chars_var.cuda().long() - _, linear_out, alignments, _ = model.forward(chars_var) - linear_out = linear_out[0].data.cpu().numpy() - wav = ap.inv_spectrogram(linear_out.T) - return wav, linear_out, alignments \ No newline at end of file + text_cleaner = [text_cleaner] + seq = np.array(text_to_sequence(text, text_cleaner)) + chars_var = torch.from_numpy(seq).unsqueeze(0) + if use_cuda: + chars_var = chars_var.cuda().long() + _, linear_out, alignments, _ = model.forward(chars_var) + linear_out = linear_out[0].data.cpu().numpy() + wav = 
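save_checkpoint() above bundles the model and optimizer state with step/epoch/loss bookkeeping, and the restore branch in main() reads the same keys back. A minimal round-trip sketch with a toy model; the file name and the tiny nn.Linear module are illustrative, and the separate stopnet optimizer stored by the real code is omitted to keep the example short:

    import datetime
    import torch
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Save: one dict holding weights, optimizer state and bookkeeping fields.
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': 1000,
        'epoch': 3,
        'linear_loss': 0.123,
        'date': datetime.date.today().strftime("%B %d, %Y"),
    }
    torch.save(state, 'checkpoint_1000.pth.tar')

    # Restore: load by key, exactly as the restore branch above does.
    checkpoint = torch.load('checkpoint_1000.pth.tar')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print(" > Model restored from step %d" % checkpoint['step'])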
ap.inv_spectrogram(linear_out.T) + return wav, linear_out, alignments diff --git a/utils/text/__init__.py b/utils/text/__init__.py index 3d158c99..37716fa9 100644 --- a/utils/text/__init__.py +++ b/utils/text/__init__.py @@ -4,7 +4,6 @@ import re from utils.text import cleaners from utils.text.symbols import symbols - # Mappings from symbol to numeric ID and vice versa: _symbol_to_id = {s: i for i, s in enumerate(symbols)} _id_to_symbol = {i: s for i, s in enumerate(symbols)} diff --git a/utils/text/cleaners.py b/utils/text/cleaners.py index 57f39ba6..31c04ae4 100644 --- a/utils/text/cleaners.py +++ b/utils/text/cleaners.py @@ -14,31 +14,31 @@ import re from unidecode import unidecode from .numbers import normalize_numbers - # Regular expression matching whitespace: _whitespace_re = re.compile(r'\s+') # List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ - ('mrs', 'misess'), - ('mr', 'mister'), - ('dr', 'doctor'), - ('st', 'saint'), - ('co', 'company'), - ('jr', 'junior'), - ('maj', 'major'), - ('gen', 'general'), - ('drs', 'doctors'), - ('rev', 'reverend'), - ('lt', 'lieutenant'), - ('hon', 'honorable'), - ('sgt', 'sergeant'), - ('capt', 'captain'), - ('esq', 'esquire'), - ('ltd', 'limited'), - ('col', 'colonel'), - ('ft', 'fort'), -]] +_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) + for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), + ]] def expand_abbreviations(text): diff --git a/utils/text/cmudict.py b/utils/text/cmudict.py index 59bd7a73..291ad33f 100644 --- a/utils/text/cmudict.py +++ b/utils/text/cmudict.py @@ -1,17 +1,16 @@ # -*- coding: utf-8 -*- - import re - valid_symbols = [ - 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', - 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', - 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', - 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', - 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', - 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', - 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' + 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', + 'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', + 'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', + 'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', + 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', + 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', + 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', + 'Y', 'Z', 'ZH' ] _valid_symbol_set = set(valid_symbols) @@ -27,8 +26,10 @@ class CMUDict: else: entries = _parse_cmudict(file_or_path) if not keep_ambiguous: - entries = {word: pron for word, - pron in entries.items() if len(pron) == 1} + entries = { + word: pron + for word, pron in entries.items() if len(pron) == 1 + } self._entries = entries def __len__(self): diff --git 
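The _abbreviations table in utils/text/cleaners.py above pairs a compiled '\babbrev\.' regex with its spoken expansion, and expand_abbreviations() applies each pair in order with re.sub. A trimmed-down sketch with only three of the entries, to show the mechanics:

    import re

    abbreviations = [(re.compile('\\b%s\\.' % abbr, re.IGNORECASE), full)
                     for abbr, full in [('mrs', 'misess'), ('dr', 'doctor'),
                                        ('st', 'saint')]]

    def expand_abbreviations(text):
        for regex, replacement in abbreviations:
            text = re.sub(regex, replacement, text)
        return text

    print(expand_abbreviations("Dr. Smith lives on St. James St."))
    # -> "doctor Smith lives on saint James saint"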
a/utils/text/numbers.py b/utils/text/numbers.py index 74484ad3..9cc6f4df 100644 --- a/utils/text/numbers.py +++ b/utils/text/numbers.py @@ -8,61 +8,45 @@ _ordinal_re = re.compile(r'([0-9]+)(st|nd|rd|th)') _number_re = re.compile(r'[0-9]+') _units = [ - '', - 'one', - 'two', - 'three', - 'four', - 'five', - 'six', - 'seven', - 'eight', - 'nine', - 'ten', - 'eleven', - 'twelve', - 'thirteen', - 'fourteen', - 'fifteen', - 'sixteen', - 'seventeen', - 'eighteen', - 'nineteen' + '', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', + 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen', + 'seventeen', 'eighteen', 'nineteen' ] _tens = [ - '', - 'ten', - 'twenty', - 'thirty', - 'forty', - 'fifty', - 'sixty', - 'seventy', - 'eighty', - 'ninety', + '', + 'ten', + 'twenty', + 'thirty', + 'forty', + 'fifty', + 'sixty', + 'seventy', + 'eighty', + 'ninety', ] _digit_groups = [ - '', - 'thousand', - 'million', - 'billion', - 'trillion', - 'quadrillion', + '', + 'thousand', + 'million', + 'billion', + 'trillion', + 'quadrillion', ] _ordinal_suffixes = [ - ('one', 'first'), - ('two', 'second'), - ('three', 'third'), - ('five', 'fifth'), - ('eight', 'eighth'), - ('nine', 'ninth'), - ('twelve', 'twelfth'), - ('ty', 'tieth'), + ('one', 'first'), + ('two', 'second'), + ('three', 'third'), + ('five', 'fifth'), + ('eight', 'eighth'), + ('nine', 'ninth'), + ('twelve', 'twelfth'), + ('ty', 'tieth'), ] + def _remove_commas(m): return m.group(1).replace(',', '') @@ -114,7 +98,7 @@ def _standard_number_to_words(n, digit_group): def _number_to_words(n): # Handle special cases first, then go to the standard case: if n >= 1000000000000000000: - return str(n) # Too large, just return the digits + return str(n) # Too large, just return the digits elif n == 0: return 'zero' elif n % 100 == 0 and n % 1000 != 0 and n < 3000: diff --git a/utils/text/symbols.py b/utils/text/symbols.py index c8550e1d..4c8f6c43 100644 --- a/utils/text/symbols.py +++ b/utils/text/symbols.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- - - ''' Defines the set of symbols used in text input to the model. @@ -19,6 +17,5 @@ _arpabet = ['@' + s for s in cmudict.valid_symbols] # Export all symbols: symbols = [_pad, _eos] + list(_characters) + _arpabet - if __name__ == '__main__': print(symbols) diff --git a/utils/visual.py b/utils/visual.py index 1cf50f5d..862e7e5c 100644 --- a/utils/visual.py +++ b/utils/visual.py @@ -6,8 +6,8 @@ import matplotlib.pyplot as plt def plot_alignment(alignment, info=None): fig, ax = plt.subplots(figsize=(16, 10)) - im = ax.imshow(alignment.T, aspect='auto', origin='lower', - interpolation='none') + im = ax.imshow( + alignment.T, aspect='auto', origin='lower', interpolation='none') fig.colorbar(im, ax=ax) xlabel = 'Decoder timestep' if info is not None: @@ -17,7 +17,7 @@ def plot_alignment(alignment, info=None): plt.tight_layout() fig.canvas.draw() data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, )) plt.close() return data @@ -30,6 +30,6 @@ def plot_spectrogram(linear_output, audio): plt.tight_layout() fig.canvas.draw() data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, )) plt.close() - return data \ No newline at end of file + return data
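plot_alignment() and plot_spectrogram() in utils/visual.py above rasterize a matplotlib figure into an HxWx3 uint8 array for TensorBoard using fig.canvas.tostring_rgb and np.fromstring; both calls are deprecated in newer Matplotlib/NumPy releases. A sketch of an equivalent conversion through the canvas RGBA buffer (an alternative, not what the diff itself does; figure size and the random image are illustrative):

    import matplotlib
    matplotlib.use('Agg')                     # render off-screen
    import matplotlib.pyplot as plt
    import numpy as np

    def figure_to_array(fig):
        # Draw the canvas, grab the RGBA buffer and drop the alpha channel.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[..., :3].copy()

    fig, ax = plt.subplots(figsize=(4, 3))
    ax.imshow(np.random.rand(20, 30), aspect='auto', origin='lower',
              interpolation='none')
    img = figure_to_array(fig)
    plt.close(fig)
    print(img.shape)                          # e.g. (300, 400, 3) at the default 100 dpi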