added .to(device) for cpu/gpu support + formatting

pull/10/head
sanjaesc 2020-10-22 10:44:00 +02:00 committed by erogol
parent ea9d8755de
commit 91e5f8b63d
3 changed files with 145 additions and 137 deletions
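In short: tensors that were pinned to the GPU with hard-coded .cuda() calls are now placed with .to(device) or .to(x.device), so training and synthesis also run on CPU-only machines; the remaining changes are line-wrapping and renames. A minimal sketch of the pattern (shapes and names are illustrative, not from the diff):

import torch

use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'

# before: fails on machines without CUDA
# h = torch.zeros(1, 8, 512).cuda()

# after: placement follows a flag ...
h = torch.zeros(1, 8, 512).to(device)

# ... or follows an existing tensor, as in WaveRNN.forward below
x = torch.randn(8, 100)
h1 = torch.zeros(1, x.size(0), 512).to(x.device)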

View File

@@ -44,43 +44,41 @@ def setup_loader(ap, is_val=False, verbose=False):
if is_val and not CONFIG.run_eval:
loader = None
else:
dataset = WaveRNNDataset(
ap=ap,
items=eval_data if is_val else train_data,
seq_len=CONFIG.seq_len,
hop_len=ap.hop_length,
pad=CONFIG.padding,
mode=CONFIG.mode,
is_training=not is_val,
verbose=verbose,
)
dataset = WaveRNNDataset(ap=ap,
items=eval_data if is_val else train_data,
seq_len=CONFIG.seq_len,
hop_len=ap.hop_length,
pad=CONFIG.padding,
mode=CONFIG.mode,
is_training=not is_val,
verbose=verbose,
)
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(
dataset,
shuffle=True,
collate_fn=dataset.collate,
batch_size=CONFIG.batch_size,
num_workers=CONFIG.num_val_loader_workers
if is_val
else CONFIG.num_loader_workers,
pin_memory=True,
)
loader = DataLoader(dataset,
shuffle=True,
collate_fn=dataset.collate,
batch_size=CONFIG.batch_size,
num_workers=CONFIG.num_val_loader_workers
if is_val
else CONFIG.num_loader_workers,
pin_memory=True,
)
return loader
def format_data(data):
# setup input data
x = data[0]
m = data[1]
y = data[2]
x_input = data[0]
mels = data[1]
y_coarse = data[2]
# dispatch data to GPU
if use_cuda:
x = x.cuda(non_blocking=True)
m = m.cuda(non_blocking=True)
y = y.cuda(non_blocking=True)
x_input = x_input.cuda(non_blocking=True)
mels = mels.cuda(non_blocking=True)
y_coarse = y_coarse.cuda(non_blocking=True)
return x, m, y
return x_input, mels, y_coarse
def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
@@ -90,7 +88,8 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
epoch_time = 0
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(len(data_loader.dataset) / (CONFIG.batch_size * num_gpus))
batch_n_iter = int(len(data_loader.dataset) /
(CONFIG.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / CONFIG.batch_size)
end_time = time.time()
@@ -99,30 +98,31 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
print(" > Training", flush=True)
for num_iter, data in enumerate(data_loader):
start_time = time.time()
x, m, y = format_data(data)
x_input, mels, y_coarse = format_data(data)
loader_time = time.time() - end_time
global_step += 1
##################
# MODEL TRAINING #
##################
y_hat = model(x, m)
y_hat = model(x_input, mels)
if isinstance(model.mode, int):
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
else:
y = y.float()
y = y.unsqueeze(-1)
y_coarse = y_coarse.float()
y_coarse = y_coarse.unsqueeze(-1)
# m_scaled, _ = model.upsample(m)
# compute losses
loss = criterion(y_hat, y)
loss = criterion(y_hat, y_coarse)
if loss.item() is None:
raise RuntimeError(" [!] None loss. Exiting ...")
optimizer.zero_grad()
loss.backward()
if CONFIG.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG.grad_clip)
torch.nn.utils.clip_grad_norm_(
model.parameters(), CONFIG.grad_clip)
optimizer.step()
if scheduler is not None:
@@ -145,19 +145,17 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
# print training stats
if global_step % CONFIG.print_step == 0:
log_dict = {
"step_time": [step_time, 2],
"loader_time": [loader_time, 4],
"current_lr": cur_lr,
}
c_logger.print_train_step(
batch_n_iter,
num_iter,
global_step,
log_dict,
loss_dict,
keep_avg.avg_values,
)
log_dict = {"step_time": [step_time, 2],
"loader_time": [loader_time, 4],
"current_lr": cur_lr,
}
c_logger.print_train_step(batch_n_iter,
num_iter,
global_step,
log_dict,
loss_dict,
keep_avg.avg_values,
)
# plot step stats
if global_step % 10 == 0:
@@ -169,40 +167,38 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
if global_step % CONFIG.save_step == 0:
if CONFIG.checkpoint:
# save model
save_checkpoint(
model,
optimizer,
scheduler,
None,
None,
None,
global_step,
epoch,
OUT_PATH,
model_losses=loss_dict,
)
save_checkpoint(model,
optimizer,
scheduler,
None,
None,
None,
global_step,
epoch,
OUT_PATH,
model_losses=loss_dict,
)
# synthesize a full voice
wav_path = train_data[random.randrange(0, len(train_data))][0]
wav = ap.load_wav(wav_path)
ground_mel = ap.melspectrogram(wav)
sample_wav = model.generate(
ground_mel,
CONFIG.batched,
CONFIG.target_samples,
CONFIG.overlap_samples,
)
sample_wav = model.generate(ground_mel,
CONFIG.batched,
CONFIG.target_samples,
CONFIG.overlap_samples,
use_cuda,
)
predict_mel = ap.melspectrogram(sample_wav)
# compute spectrograms
figures = {
"train/ground_truth": plot_spectrogram(ground_mel.T),
"train/prediction": plot_spectrogram(predict_mel.T),
}
figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
"train/prediction": plot_spectrogram(predict_mel.T),
}
# Sample audio
tb_logger.tb_train_audios(
global_step, {"train/audio": sample_wav}, CONFIG.audio["sample_rate"]
global_step, {
"train/audio": sample_wav}, CONFIG.audio["sample_rate"]
)
tb_logger.tb_train_figures(global_step, figures)
@@ -234,17 +230,17 @@ def evaluate(model, criterion, ap, global_step, epoch):
for num_iter, data in enumerate(data_loader):
start_time = time.time()
# format data
x, m, y = format_data(data)
x_input, mels, y_coarse = format_data(data)
loader_time = time.time() - end_time
global_step += 1
y_hat = model(x, m)
y_hat = model(x_input, mels)
if isinstance(model.mode, int):
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
else:
y = y.float()
y = y.unsqueeze(-1)
loss = criterion(y_hat, y)
y_coarse = y_coarse.float()
y_coarse = y_coarse.unsqueeze(-1)
loss = criterion(y_hat, y_coarse)
# Compute avg loss
# if num_gpus > 1:
# loss = reduce_tensor(loss.data, num_gpus)
@@ -264,30 +260,31 @@ def evaluate(model, criterion, ap, global_step, epoch):
# print eval stats
if CONFIG.print_eval:
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
c_logger.print_eval_step(
num_iter, loss_dict, keep_avg.avg_values)
if epoch % CONFIG.test_every_epochs == 0:
if epoch % CONFIG.test_every_epochs == 0 and epoch != 0:
# synthesize a part of data
wav_path = eval_data[random.randrange(0, len(eval_data))][0]
wav = ap.load_wav(wav_path)
ground_mel = ap.melspectrogram(wav[:22000])
sample_wav = model.generate(
ground_mel,
CONFIG.batched,
CONFIG.target_samples,
CONFIG.overlap_samples,
)
sample_wav = model.generate(ground_mel,
CONFIG.batched,
CONFIG.target_samples,
CONFIG.overlap_samples,
use_cuda
)
predict_mel = ap.melspectrogram(sample_wav)
# compute spectrograms
figures = {
"eval/ground_truth": plot_spectrogram(ground_mel.T),
"eval/prediction": plot_spectrogram(predict_mel.T),
}
figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
"eval/prediction": plot_spectrogram(predict_mel.T),
}
# Sample audio
tb_logger.tb_eval_audios(
global_step, {"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
global_step, {
"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
)
tb_logger.tb_eval_figures(global_step, figures)
@@ -372,7 +369,8 @@ def main(args): # pylint: disable=redefined-outer-name
model_dict = set_init_dict(model_dict, checkpoint["model"], CONFIG)
model_wavernn.load_state_dict(model_dict)
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
print(" > Model restored from step %d" %
checkpoint["step"], flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0
@@ -393,7 +391,8 @@ def main(args): # pylint: disable=redefined-outer-name
_, global_step = train(
model_wavernn, optimizer, criterion, scheduler, ap, global_step, epoch
)
eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch)
eval_avg_loss_dict = evaluate(
model_wavernn, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = eval_avg_loss_dict["avg_model_loss"]
best_loss = save_best_model(
@@ -493,7 +492,8 @@ if __name__ == "__main__":
tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
# write model desc to tensorboard
tb_logger.tb_add_text("model-description", CONFIG["run_description"], 0)
tb_logger.tb_add_text("model-description",
CONFIG["run_description"], 0)
try:
main(args)

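Because generate() in the model file below now takes use_cuda explicitly, a standalone caller would look roughly like this (a sketch; mel and the numeric chunk sizes are assumed, not from this diff):

import torch

use_cuda = torch.cuda.is_available()
# mel: a [n_mels, T] spectrogram, e.g. from ap.melspectrogram(wav)
wav_out = model.generate(mel,
                         True,    # batched: fold into overlapping chunks
                         11000,   # target samples per chunk (assumed value)
                         550,     # overlap samples for crossfading (assumed value)
                         use_cuda)
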
View File

@@ -8,17 +8,16 @@ class WaveRNNDataset(Dataset):
WaveRNN Dataset searches for all the wav files under the root path.
"""
def __init__(
self,
ap,
items,
seq_len,
hop_len,
pad,
mode,
is_training=True,
verbose=False,
):
def __init__(self,
ap,
items,
seq_len,
hop_len,
pad,
mode,
is_training=True,
verbose=False,
):
self.ap = ap
self.item_list = items
@@ -56,17 +55,19 @@ class WaveRNNDataset(Dataset):
def collate(self, batch):
mel_win = self.seq_len // self.hop_len + 2 * self.pad
max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch]
max_offsets = [x[0].shape[-1] -
(mel_win + 2 * self.pad) for x in batch]
mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets]
sig_offsets = [(offset + self.pad) *
self.hop_len for offset in mel_offsets]
mels = [
x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win]
x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win]
for i, x in enumerate(batch)
]
coarse = [
x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1]
x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1]
for i, x in enumerate(batch)
]
@@ -79,7 +80,8 @@ class WaveRNNDataset(Dataset):
coarse = np.stack(coarse).astype(np.int64)
coarse = torch.LongTensor(coarse)
x_input = (
2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0
2 * coarse[:, : self.seq_len].float() /
(2 ** self.mode - 1.0) - 1.0
)
y_coarse = coarse[:, 1:]
mels = torch.FloatTensor(mels)

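A worked example of the collate windowing and input scaling above (config values assumed for illustration):

seq_len, hop_len, pad, mode = 1280, 256, 2, 9   # assumed config values
mel_win = seq_len // hop_len + 2 * pad           # 5 + 4 = 9 mel frames
# a clip with 100 mel frames allows offsets in [0, 100 - (9 + 4)) = [0, 87)
# sig_offset = (mel_offset + pad) * hop_len keeps the audio slice aligned
# with its mel window; the slice is seq_len + 1 samples long so that
# x_input (inputs) and y_coarse (targets) are shifted by one sample.
# x_input rescales 9-bit ints [0, 511] onto [-1.0, 1.0]:
assert 2 * 0 / (2 ** mode - 1.0) - 1.0 == -1.0
assert 2 * 511 / (2 ** mode - 1.0) - 1.0 == 1.0
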
View File

@@ -39,7 +39,8 @@ class MelResNet(nn.Module):
def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
super().__init__()
k_size = pad * 2 + 1
self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
self.conv_in = nn.Conv1d(
in_dims, compute_dims, kernel_size=k_size, bias=False)
self.batch_norm = nn.BatchNorm1d(compute_dims)
self.layers = nn.ModuleList()
for _ in range(res_blocks):
@@ -94,7 +95,8 @@ class UpsampleNetwork(nn.Module):
k_size = (1, scale * 2 + 1)
padding = (0, scale)
stretch = Stretch2d(scale, 1)
conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
conv = nn.Conv2d(1, 1, kernel_size=k_size,
padding=padding, bias=False)
conv.weight.data.fill_(1.0 / k_size[1])
self.up_layers.append(stretch)
self.up_layers.append(conv)
@@ -110,7 +112,7 @@ class UpsampleNetwork(nn.Module):
m = m.unsqueeze(1)
for f in self.up_layers:
m = f(m)
m = m.squeeze(1)[:, :, self.indent : -self.indent]
m = m.squeeze(1)[:, :, self.indent: -self.indent]
return m.transpose(1, 2), aux
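For reference, the upsampling above is nearest-neighbour stretching followed by a conv whose kernel is initialized as a moving average, turning the frame-repeat steps into ramps; a small sketch (scale assumed to be 4):

import torch
import torch.nn as nn

scale = 4  # assumed; one entry of the model's upsample factors
conv = nn.Conv2d(1, 1, kernel_size=(1, 2 * scale + 1),
                 padding=(0, scale), bias=False)
conv.weight.data.fill_(1.0 / (2 * scale + 1))  # moving-average init
m = torch.arange(8.0).view(1, 1, 1, 8)
m = m.repeat_interleave(scale, dim=-1)  # Stretch2d-like frame repeat
smoothed = conv(m)  # same width, steps smoothed into ramps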
@@ -123,7 +125,8 @@ class Upsample(nn.Module):
self.pad = pad
self.indent = pad * scale
self.use_aux_net = use_aux_net
self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
self.resnet = MelResNet(res_blocks, feat_dims,
compute_dims, res_out_dims, pad)
def forward(self, m):
if self.use_aux_net:
@@ -137,7 +140,7 @@ class Upsample(nn.Module):
m = torch.nn.functional.interpolate(
m, scale_factor=self.scale, mode="linear", align_corners=True
)
m = m[:, :, self.indent : -self.indent]
m = m[:, :, self.indent: -self.indent]
m = m * 0.045 # empirically found
return m.transpose(1, 2), aux
@@ -207,7 +210,8 @@ class WaveRNN(nn.Module):
if self.use_aux_net:
self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
self.rnn2 = nn.GRU(rnn_dims + self.aux_dims,
rnn_dims, batch_first=True)
self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
self.fc3 = nn.Linear(fc_dims, self.n_classes)
@@ -221,16 +225,16 @@ class WaveRNN(nn.Module):
def forward(self, x, mels):
bsize = x.size(0)
h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
h1 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
h2 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
mels, aux = self.upsample(mels)
if self.use_aux_net:
aux_idx = [self.aux_dims * i for i in range(5)]
a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
a1 = aux[:, :, aux_idx[0]: aux_idx[1]]
a2 = aux[:, :, aux_idx[1]: aux_idx[2]]
a3 = aux[:, :, aux_idx[2]: aux_idx[3]]
a4 = aux[:, :, aux_idx[3]: aux_idx[4]]
x = (
torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
@@ -256,19 +260,21 @@ class WaveRNN(nn.Module):
x = F.relu(self.fc2(x))
return self.fc3(x)
def generate(self, mels, batched, target, overlap):
def generate(self, mels, batched, target, overlap, use_cuda):
self.eval()
device = 'cuda' if use_cuda else 'cpu'
output = []
start = time.time()
rnn1 = self.get_gru_cell(self.rnn1)
rnn2 = self.get_gru_cell(self.rnn2)
with torch.no_grad():
mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
#mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
wave_len = (mels.size(-1) - 1) * self.hop_length
mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both")
mels = self.pad_tensor(mels.transpose(
1, 2), pad=self.pad, side="both")
mels, aux = self.upsample(mels.transpose(1, 2))
if batched:
@@ -278,13 +284,13 @@ class WaveRNN(nn.Module):
b_size, seq_len, _ = mels.size()
h1 = torch.zeros(b_size, self.rnn_dims).cuda()
h2 = torch.zeros(b_size, self.rnn_dims).cuda()
x = torch.zeros(b_size, 1).cuda()
h1 = torch.zeros(b_size, self.rnn_dims).to(device)
h2 = torch.zeros(b_size, self.rnn_dims).to(device)
x = torch.zeros(b_size, 1).to(device)
if self.use_aux_net:
d = self.aux_dims
aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
aux_split = [aux[:, :, d * i: d * (i + 1)] for i in range(4)]
for i in range(seq_len):
@@ -319,11 +325,12 @@ class WaveRNN(nn.Module):
logits.unsqueeze(0).transpose(1, 2)
)
output.append(sample.view(-1))
x = sample.transpose(0, 1).cuda()
x = sample.transpose(0, 1).to(device)
elif self.mode == "gauss":
sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
sample = sample_from_gaussian(
logits.unsqueeze(0).transpose(1, 2))
output.append(sample.view(-1))
x = sample.transpose(0, 1).cuda()
x = sample.transpose(0, 1).to(device)
elif isinstance(self.mode, int):
posterior = F.softmax(logits, dim=1)
distrib = torch.distributions.Categorical(posterior)
@@ -332,7 +339,8 @@ class WaveRNN(nn.Module):
output.append(sample)
x = sample.unsqueeze(-1)
else:
raise RuntimeError("Unknown model mode value - ", self.mode)
raise RuntimeError(
"Unknown model mode value - ", self.mode)
if i % 100 == 0:
self.gen_display(i, seq_len, b_size, start)
@@ -352,7 +360,7 @@ class WaveRNN(nn.Module):
# Fade-out at the end to avoid signal cutting out suddenly
fade_out = np.linspace(1, 0, 20 * self.hop_length)
output = output[:wave_len]
output[-20 * self.hop_length :] *= fade_out
output[-20 * self.hop_length:] *= fade_out
self.train()
return output
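The fade-out above is a linear ramp over the last 20 hops so batched generation does not end with a click; numerically (hop_length assumed to be 256):

import numpy as np

hop_length = 256  # assumed; ap.hop_length in the configs
output = np.ones(30000, dtype=np.float32)
fade_out = np.linspace(1, 0, 20 * hop_length)  # 5120-sample ramp, 1.0 -> 0.0
output[-20 * hop_length:] *= fade_out          # tail decays smoothly to zero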
@@ -366,7 +374,6 @@ class WaveRNN(nn.Module):
)
def fold_with_overlap(self, x, target, overlap):
"""Fold the tensor with overlap for quick batched inference.
Overlap will be used for crossfading in xfade_and_unfold()
Args:
@@ -398,7 +405,7 @@ class WaveRNN(nn.Module):
padding = target + 2 * overlap - remaining
x = self.pad_tensor(x, padding, side="after")
folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda()
folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)
# Get the values for the folded tensor
for i in range(num_folds):
@@ -423,16 +430,15 @@ class WaveRNN(nn.Module):
# i.e., it won't generalise to other shapes/dims
b, t, c = x.size()
total = t + 2 * pad if side == "both" else t + pad
padded = torch.zeros(b, total, c).cuda()
padded = torch.zeros(b, total, c).to(x.device)
if side in ("before", "both"):
padded[:, pad : pad + t, :] = x
padded[:, pad: pad + t, :] = x
elif side == "after":
padded[:, :t, :] = x
return padded
@staticmethod
def xfade_and_unfold(y, target, overlap):
"""Applies a crossfade and unfolds into a 1d array.
Args:
y (ndarray): Batched sequences of audio samples