From 91e5f8b63dc9f1ba3eed123ce462b1000569e8b1 Mon Sep 17 00:00:00 2001
From: sanjaesc
Date: Thu, 22 Oct 2020 10:44:00 +0200
Subject: [PATCH] added to device cpu/gpu + formatting

---
 TTS/bin/train_wavernn_vocoder.py        | 182 ++++++++++++------------
 TTS/vocoder/datasets/wavernn_dataset.py |  34 ++---
 TTS/vocoder/models/wavernn.py           |  66 +++++----
 3 files changed, 145 insertions(+), 137 deletions(-)

diff --git a/TTS/bin/train_wavernn_vocoder.py b/TTS/bin/train_wavernn_vocoder.py
index 78984510..66a7c913 100644
--- a/TTS/bin/train_wavernn_vocoder.py
+++ b/TTS/bin/train_wavernn_vocoder.py
@@ -44,43 +44,41 @@ def setup_loader(ap, is_val=False, verbose=False):
     if is_val and not CONFIG.run_eval:
         loader = None
     else:
-        dataset = WaveRNNDataset(
-            ap=ap,
-            items=eval_data if is_val else train_data,
-            seq_len=CONFIG.seq_len,
-            hop_len=ap.hop_length,
-            pad=CONFIG.padding,
-            mode=CONFIG.mode,
-            is_training=not is_val,
-            verbose=verbose,
-        )
+        dataset = WaveRNNDataset(ap=ap,
+                                 items=eval_data if is_val else train_data,
+                                 seq_len=CONFIG.seq_len,
+                                 hop_len=ap.hop_length,
+                                 pad=CONFIG.padding,
+                                 mode=CONFIG.mode,
+                                 is_training=not is_val,
+                                 verbose=verbose,
+                                 )
         # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
-        loader = DataLoader(
-            dataset,
-            shuffle=True,
-            collate_fn=dataset.collate,
-            batch_size=CONFIG.batch_size,
-            num_workers=CONFIG.num_val_loader_workers
-            if is_val
-            else CONFIG.num_loader_workers,
-            pin_memory=True,
-        )
+        loader = DataLoader(dataset,
+                            shuffle=True,
+                            collate_fn=dataset.collate,
+                            batch_size=CONFIG.batch_size,
+                            num_workers=CONFIG.num_val_loader_workers
+                            if is_val
+                            else CONFIG.num_loader_workers,
+                            pin_memory=True,
+                            )
     return loader
 
 
 def format_data(data):
     # setup input data
-    x = data[0]
-    m = data[1]
-    y = data[2]
+    x_input = data[0]
+    mels = data[1]
+    y_coarse = data[2]
 
     # dispatch data to GPU
     if use_cuda:
-        x = x.cuda(non_blocking=True)
-        m = m.cuda(non_blocking=True)
-        y = y.cuda(non_blocking=True)
+        x_input = x_input.cuda(non_blocking=True)
+        mels = mels.cuda(non_blocking=True)
+        y_coarse = y_coarse.cuda(non_blocking=True)
 
-    return x, m, y
+    return x_input, mels, y_coarse
 
 
 def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
@@ -90,7 +88,8 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
     epoch_time = 0
     keep_avg = KeepAverage()
     if use_cuda:
-        batch_n_iter = int(len(data_loader.dataset) / (CONFIG.batch_size * num_gpus))
+        batch_n_iter = int(len(data_loader.dataset) /
+                           (CONFIG.batch_size * num_gpus))
     else:
         batch_n_iter = int(len(data_loader.dataset) / CONFIG.batch_size)
     end_time = time.time()
@@ -99,30 +98,31 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
     print(" > Training", flush=True)
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
-        x, m, y = format_data(data)
+        x_input, mels, y_coarse = format_data(data)
         loader_time = time.time() - end_time
         global_step += 1
 
         ##################
         # MODEL TRAINING #
         ##################
-        y_hat = model(x, m)
+        y_hat = model(x_input, mels)
 
         if isinstance(model.mode, int):
             y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
         else:
-            y = y.float()
-            y = y.unsqueeze(-1)
+            y_coarse = y_coarse.float()
+            y_coarse = y_coarse.unsqueeze(-1)
         # m_scaled, _ = model.upsample(m)
 
         # compute losses
-        loss = criterion(y_hat, y)
+        loss = criterion(y_hat, y_coarse)
         if loss.item() is None:
             raise RuntimeError(" [!] None loss. Exiting ...")
Exiting ...") optimizer.zero_grad() loss.backward() if CONFIG.grad_clip > 0: - torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG.grad_clip) + torch.nn.utils.clip_grad_norm_( + model.parameters(), CONFIG.grad_clip) optimizer.step() if scheduler is not None: @@ -145,19 +145,17 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch): # print training stats if global_step % CONFIG.print_step == 0: - log_dict = { - "step_time": [step_time, 2], - "loader_time": [loader_time, 4], - "current_lr": cur_lr, - } - c_logger.print_train_step( - batch_n_iter, - num_iter, - global_step, - log_dict, - loss_dict, - keep_avg.avg_values, - ) + log_dict = {"step_time": [step_time, 2], + "loader_time": [loader_time, 4], + "current_lr": cur_lr, + } + c_logger.print_train_step(batch_n_iter, + num_iter, + global_step, + log_dict, + loss_dict, + keep_avg.avg_values, + ) # plot step stats if global_step % 10 == 0: @@ -169,40 +167,38 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch): if global_step % CONFIG.save_step == 0: if CONFIG.checkpoint: # save model - save_checkpoint( - model, - optimizer, - scheduler, - None, - None, - None, - global_step, - epoch, - OUT_PATH, - model_losses=loss_dict, - ) + save_checkpoint(model, + optimizer, + scheduler, + None, + None, + None, + global_step, + epoch, + OUT_PATH, + model_losses=loss_dict, + ) # synthesize a full voice wav_path = train_data[random.randrange(0, len(train_data))][0] wav = ap.load_wav(wav_path) ground_mel = ap.melspectrogram(wav) - sample_wav = model.generate( - ground_mel, - CONFIG.batched, - CONFIG.target_samples, - CONFIG.overlap_samples, - ) + sample_wav = model.generate(ground_mel, + CONFIG.batched, + CONFIG.target_samples, + CONFIG.overlap_samples, + ) predict_mel = ap.melspectrogram(sample_wav) # compute spectrograms - figures = { - "train/ground_truth": plot_spectrogram(ground_mel.T), - "train/prediction": plot_spectrogram(predict_mel.T), - } + figures = {"train/ground_truth": plot_spectrogram(ground_mel.T), + "train/prediction": plot_spectrogram(predict_mel.T), + } # Sample audio tb_logger.tb_train_audios( - global_step, {"train/audio": sample_wav}, CONFIG.audio["sample_rate"] + global_step, { + "train/audio": sample_wav}, CONFIG.audio["sample_rate"] ) tb_logger.tb_train_figures(global_step, figures) @@ -234,17 +230,17 @@ def evaluate(model, criterion, ap, global_step, epoch): for num_iter, data in enumerate(data_loader): start_time = time.time() # format data - x, m, y = format_data(data) + x_input, mels, y_coarse = format_data(data) loader_time = time.time() - end_time global_step += 1 - y_hat = model(x, m) + y_hat = model(x_input, mels) if isinstance(model.mode, int): y_hat = y_hat.transpose(1, 2).unsqueeze(-1) else: - y = y.float() - y = y.unsqueeze(-1) - loss = criterion(y_hat, y) + y_coarse = y_coarse.float() + y_coarse = y_coarse.unsqueeze(-1) + loss = criterion(y_hat, y_coarse) # Compute avg loss # if num_gpus > 1: # loss = reduce_tensor(loss.data, num_gpus) @@ -264,30 +260,31 @@ def evaluate(model, criterion, ap, global_step, epoch): # print eval stats if CONFIG.print_eval: - c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) + c_logger.print_eval_step( + num_iter, loss_dict, keep_avg.avg_values) - if epoch % CONFIG.test_every_epochs == 0: + if epoch % CONFIG.test_every_epochs == 0 and epoch != 0: # synthesize a part of data wav_path = eval_data[random.randrange(0, len(eval_data))][0] wav = ap.load_wav(wav_path) ground_mel = ap.melspectrogram(wav[:22000]) - sample_wav = 
-            ground_mel,
-            CONFIG.batched,
-            CONFIG.target_samples,
-            CONFIG.overlap_samples,
-        )
+        sample_wav = model.generate(ground_mel,
+                                    CONFIG.batched,
+                                    CONFIG.target_samples,
+                                    CONFIG.overlap_samples,
+                                    use_cuda
+                                    )
         predict_mel = ap.melspectrogram(sample_wav)
 
         # compute spectrograms
-        figures = {
-            "eval/ground_truth": plot_spectrogram(ground_mel.T),
-            "eval/prediction": plot_spectrogram(predict_mel.T),
-        }
+        figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
+                   "eval/prediction": plot_spectrogram(predict_mel.T),
+                   }
 
         # Sample audio
         tb_logger.tb_eval_audios(
-            global_step, {"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
+            global_step, {
                "eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
         )
         tb_logger.tb_eval_figures(global_step, figures)
 
@@ -372,7 +369,8 @@ def main(args):  # pylint: disable=redefined-outer-name
             model_dict = set_init_dict(model_dict, checkpoint["model"], CONFIG)
             model_wavernn.load_state_dict(model_dict)
 
-        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
+        print(" > Model restored from step %d" %
+              checkpoint["step"], flush=True)
         args.restore_step = checkpoint["step"]
     else:
         args.restore_step = 0
@@ -393,7 +391,8 @@ def main(args):  # pylint: disable=redefined-outer-name
         _, global_step = train(
             model_wavernn, optimizer, criterion, scheduler, ap, global_step, epoch
         )
-        eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch)
+        eval_avg_loss_dict = evaluate(
+            model_wavernn, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = eval_avg_loss_dict["avg_model_loss"]
         best_loss = save_best_model(
@@ -493,7 +492,8 @@ if __name__ == "__main__":
     tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
 
     # write model desc to tensorboard
-    tb_logger.tb_add_text("model-description", CONFIG["run_description"], 0)
+    tb_logger.tb_add_text("model-description",
+                          CONFIG["run_description"], 0)
 
     try:
         main(args)
diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py
index 5d5b9f15..194344a9 100644
--- a/TTS/vocoder/datasets/wavernn_dataset.py
+++ b/TTS/vocoder/datasets/wavernn_dataset.py
@@ -8,17 +8,16 @@ class WaveRNNDataset(Dataset):
     WaveRNN Dataset searches for all the wav files under root path.
""" - def __init__( - self, - ap, - items, - seq_len, - hop_len, - pad, - mode, - is_training=True, - verbose=False, - ): + def __init__(self, + ap, + items, + seq_len, + hop_len, + pad, + mode, + is_training=True, + verbose=False, + ): self.ap = ap self.item_list = items @@ -56,17 +55,19 @@ class WaveRNNDataset(Dataset): def collate(self, batch): mel_win = self.seq_len // self.hop_len + 2 * self.pad - max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch] + max_offsets = [x[0].shape[-1] - + (mel_win + 2 * self.pad) for x in batch] mel_offsets = [np.random.randint(0, offset) for offset in max_offsets] - sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets] + sig_offsets = [(offset + self.pad) * + self.hop_len for offset in mel_offsets] mels = [ - x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] + x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win] for i, x in enumerate(batch) ] coarse = [ - x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] + x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch) ] @@ -79,7 +80,8 @@ class WaveRNNDataset(Dataset): coarse = np.stack(coarse).astype(np.int64) coarse = torch.LongTensor(coarse) x_input = ( - 2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0 + 2 * coarse[:, : self.seq_len].float() / + (2 ** self.mode - 1.0) - 1.0 ) y_coarse = coarse[:, 1:] mels = torch.FloatTensor(mels) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 4d1a633c..9b151cac 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -39,7 +39,8 @@ class MelResNet(nn.Module): def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad): super().__init__() k_size = pad * 2 + 1 - self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False) + self.conv_in = nn.Conv1d( + in_dims, compute_dims, kernel_size=k_size, bias=False) self.batch_norm = nn.BatchNorm1d(compute_dims) self.layers = nn.ModuleList() for _ in range(res_blocks): @@ -94,7 +95,8 @@ class UpsampleNetwork(nn.Module): k_size = (1, scale * 2 + 1) padding = (0, scale) stretch = Stretch2d(scale, 1) - conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) + conv = nn.Conv2d(1, 1, kernel_size=k_size, + padding=padding, bias=False) conv.weight.data.fill_(1.0 / k_size[1]) self.up_layers.append(stretch) self.up_layers.append(conv) @@ -110,7 +112,7 @@ class UpsampleNetwork(nn.Module): m = m.unsqueeze(1) for f in self.up_layers: m = f(m) - m = m.squeeze(1)[:, :, self.indent : -self.indent] + m = m.squeeze(1)[:, :, self.indent: -self.indent] return m.transpose(1, 2), aux @@ -123,7 +125,8 @@ class Upsample(nn.Module): self.pad = pad self.indent = pad * scale self.use_aux_net = use_aux_net - self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad) + self.resnet = MelResNet(res_blocks, feat_dims, + compute_dims, res_out_dims, pad) def forward(self, m): if self.use_aux_net: @@ -137,7 +140,7 @@ class Upsample(nn.Module): m = torch.nn.functional.interpolate( m, scale_factor=self.scale, mode="linear", align_corners=True ) - m = m[:, :, self.indent : -self.indent] + m = m[:, :, self.indent: -self.indent] m = m * 0.045 # empirically found return m.transpose(1, 2), aux @@ -207,7 +210,8 @@ class WaveRNN(nn.Module): if self.use_aux_net: self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims) self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) - self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True) + 
+            self.rnn2 = nn.GRU(rnn_dims + self.aux_dims,
+                               rnn_dims, batch_first=True)
             self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
             self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
             self.fc3 = nn.Linear(fc_dims, self.n_classes)
@@ -221,16 +225,16 @@ class WaveRNN(nn.Module):
 
     def forward(self, x, mels):
         bsize = x.size(0)
-        h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
-        h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
+        h1 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
+        h2 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
         mels, aux = self.upsample(mels)
 
         if self.use_aux_net:
             aux_idx = [self.aux_dims * i for i in range(5)]
-            a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
-            a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
-            a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
-            a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
+            a1 = aux[:, :, aux_idx[0]: aux_idx[1]]
+            a2 = aux[:, :, aux_idx[1]: aux_idx[2]]
+            a3 = aux[:, :, aux_idx[2]: aux_idx[3]]
+            a4 = aux[:, :, aux_idx[3]: aux_idx[4]]
 
         x = (
             torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
@@ -256,19 +260,21 @@ class WaveRNN(nn.Module):
         x = F.relu(self.fc2(x))
         return self.fc3(x)
 
-    def generate(self, mels, batched, target, overlap):
+    def generate(self, mels, batched, target, overlap, use_cuda):
         self.eval()
+        device = 'cuda' if use_cuda else 'cpu'
         output = []
         start = time.time()
         rnn1 = self.get_gru_cell(self.rnn1)
         rnn2 = self.get_gru_cell(self.rnn2)
 
         with torch.no_grad():
-
-            mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
+            mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
+            # mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
             wave_len = (mels.size(-1) - 1) * self.hop_length
-            mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both")
+            mels = self.pad_tensor(mels.transpose(
+                1, 2), pad=self.pad, side="both")
             mels, aux = self.upsample(mels.transpose(1, 2))
 
             if batched:
@@ -278,13 +284,13 @@ class WaveRNN(nn.Module):
 
             b_size, seq_len, _ = mels.size()
 
-            h1 = torch.zeros(b_size, self.rnn_dims).cuda()
-            h2 = torch.zeros(b_size, self.rnn_dims).cuda()
-            x = torch.zeros(b_size, 1).cuda()
+            h1 = torch.zeros(b_size, self.rnn_dims).to(device)
+            h2 = torch.zeros(b_size, self.rnn_dims).to(device)
+            x = torch.zeros(b_size, 1).to(device)
 
             if self.use_aux_net:
                 d = self.aux_dims
-                aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
+                aux_split = [aux[:, :, d * i: d * (i + 1)] for i in range(4)]
 
             for i in range(seq_len):
 
@@ -319,11 +325,12 @@ class WaveRNN(nn.Module):
                         logits.unsqueeze(0).transpose(1, 2)
                     )
                     output.append(sample.view(-1))
-                    x = sample.transpose(0, 1).cuda()
+                    x = sample.transpose(0, 1).to(device)
                 elif self.mode == "gauss":
-                    sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
+                    sample = sample_from_gaussian(
+                        logits.unsqueeze(0).transpose(1, 2))
                     output.append(sample.view(-1))
-                    x = sample.transpose(0, 1).cuda()
+                    x = sample.transpose(0, 1).to(device)
                 elif isinstance(self.mode, int):
                     posterior = F.softmax(logits, dim=1)
                     distrib = torch.distributions.Categorical(posterior)
@@ -332,7 +339,8 @@ class WaveRNN(nn.Module):
                     output.append(sample)
                     x = sample.unsqueeze(-1)
                 else:
-                    raise RuntimeError("Unknown model mode value - ", self.mode)
+                    raise RuntimeError(
+                        "Unknown model mode value - ", self.mode)
 
                 if i % 100 == 0:
                     self.gen_display(i, seq_len, b_size, start)
@@ -352,7 +360,7 @@ class WaveRNN(nn.Module):
         # Fade-out at the end to avoid signal cutting out suddenly
         fade_out = np.linspace(1, 0, 20 * self.hop_length)
         output = output[:wave_len]
-        output[-20 * self.hop_length :] *= fade_out
+        output[-20 * self.hop_length:] *= fade_out
 
         self.train()
         return output
@@ -366,7 +374,6 @@ class WaveRNN(nn.Module):
         )
 
     def fold_with_overlap(self, x, target, overlap):
-
         """Fold the tensor with overlap for quick batched inference.
            Overlap will be used for crossfading in xfade_and_unfold()
         Args:
@@ -398,7 +405,7 @@
             padding = target + 2 * overlap - remaining
             x = self.pad_tensor(x, padding, side="after")
 
-        folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda()
+        folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)
 
         # Get the values for the folded tensor
         for i in range(num_folds):
@@ -423,16 +430,15 @@
         # i.e., it won't generalise to other shapes/dims
         b, t, c = x.size()
         total = t + 2 * pad if side == "both" else t + pad
-        padded = torch.zeros(b, total, c).cuda()
+        padded = torch.zeros(b, total, c).to(x.device)
         if side in ("before", "both"):
-            padded[:, pad : pad + t, :] = x
+            padded[:, pad: pad + t, :] = x
         elif side == "after":
             padded[:, :t, :] = x
         return padded
 
     @staticmethod
     def xfade_and_unfold(y, target, overlap):
-
         """Applies a crossfade and unfolds into a 1d array.
         Args:
             y (ndarray) : Batched sequences of audio samples
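
Note: after this patch, WaveRNN.generate() takes an explicit use_cuda flag and
creates its hidden-state and padding tensors with .to(device) / .to(x.device)
instead of hard-coded .cuda() calls, so inference also works on CPU-only hosts.
A minimal sketch of the resulting call pattern; `model_wavernn`, `ap`, and
`CONFIG` are assumed to be set up as in TTS/bin/train_wavernn_vocoder.py, and
the wav path is a placeholder:

    import torch

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model_wavernn.to(device)               # parameters follow the chosen device
    wav = ap.load_wav("/data/sample.wav")  # placeholder path
    mel = ap.melspectrogram(wav)           # conditioning features

    # generate() threads use_cuda through, so its work tensors are created
    # on the same device as the model:
    sample_wav = model_wavernn.generate(mel,
                                        CONFIG.batched,
                                        CONFIG.target_samples,
                                        CONFIG.overlap_samples,
                                        use_cuda)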