mirror of https://github.com/coqui-ai/TTS.git
added to device cpu/gpu + formatting
parent
ea9d8755de
commit
91e5f8b63d
|
@ -44,43 +44,41 @@ def setup_loader(ap, is_val=False, verbose=False):
|
|||
if is_val and not CONFIG.run_eval:
|
||||
loader = None
|
||||
else:
|
||||
dataset = WaveRNNDataset(
|
||||
ap=ap,
|
||||
items=eval_data if is_val else train_data,
|
||||
seq_len=CONFIG.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad=CONFIG.padding,
|
||||
mode=CONFIG.mode,
|
||||
is_training=not is_val,
|
||||
verbose=verbose,
|
||||
)
|
||||
dataset = WaveRNNDataset(ap=ap,
|
||||
items=eval_data if is_val else train_data,
|
||||
seq_len=CONFIG.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad=CONFIG.padding,
|
||||
mode=CONFIG.mode,
|
||||
is_training=not is_val,
|
||||
verbose=verbose,
|
||||
)
|
||||
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
shuffle=True,
|
||||
collate_fn=dataset.collate,
|
||||
batch_size=CONFIG.batch_size,
|
||||
num_workers=CONFIG.num_val_loader_workers
|
||||
if is_val
|
||||
else CONFIG.num_loader_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
loader = DataLoader(dataset,
|
||||
shuffle=True,
|
||||
collate_fn=dataset.collate,
|
||||
batch_size=CONFIG.batch_size,
|
||||
num_workers=CONFIG.num_val_loader_workers
|
||||
if is_val
|
||||
else CONFIG.num_loader_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
return loader
|
||||
|
||||
|
||||
def format_data(data):
|
||||
# setup input data
|
||||
x = data[0]
|
||||
m = data[1]
|
||||
y = data[2]
|
||||
x_input = data[0]
|
||||
mels = data[1]
|
||||
y_coarse = data[2]
|
||||
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
x = x.cuda(non_blocking=True)
|
||||
m = m.cuda(non_blocking=True)
|
||||
y = y.cuda(non_blocking=True)
|
||||
x_input = x_input.cuda(non_blocking=True)
|
||||
mels = mels.cuda(non_blocking=True)
|
||||
y_coarse = y_coarse.cuda(non_blocking=True)
|
||||
|
||||
return x, m, y
|
||||
return x_input, mels, y_coarse
|
||||
|
||||
|
||||
def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
||||
|
@ -90,7 +88,8 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
|||
epoch_time = 0
|
||||
keep_avg = KeepAverage()
|
||||
if use_cuda:
|
||||
batch_n_iter = int(len(data_loader.dataset) / (CONFIG.batch_size * num_gpus))
|
||||
batch_n_iter = int(len(data_loader.dataset) /
|
||||
(CONFIG.batch_size * num_gpus))
|
||||
else:
|
||||
batch_n_iter = int(len(data_loader.dataset) / CONFIG.batch_size)
|
||||
end_time = time.time()
|
||||
|
@ -99,30 +98,31 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
|||
print(" > Training", flush=True)
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
x, m, y = format_data(data)
|
||||
x_input, mels, y_coarse = format_data(data)
|
||||
loader_time = time.time() - end_time
|
||||
global_step += 1
|
||||
|
||||
##################
|
||||
# MODEL TRAINING #
|
||||
##################
|
||||
y_hat = model(x, m)
|
||||
y_hat = model(x_input, mels)
|
||||
|
||||
if isinstance(model.mode, int):
|
||||
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
|
||||
else:
|
||||
y = y.float()
|
||||
y = y.unsqueeze(-1)
|
||||
y_coarse = y_coarse.float()
|
||||
y_coarse = y_coarse.unsqueeze(-1)
|
||||
# m_scaled, _ = model.upsample(m)
|
||||
|
||||
# compute losses
|
||||
loss = criterion(y_hat, y)
|
||||
loss = criterion(y_hat, y_coarse)
|
||||
if loss.item() is None:
|
||||
raise RuntimeError(" [!] None loss. Exiting ...")
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
if CONFIG.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG.grad_clip)
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), CONFIG.grad_clip)
|
||||
|
||||
optimizer.step()
|
||||
if scheduler is not None:
|
||||
|
@ -145,19 +145,17 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
|||
|
||||
# print training stats
|
||||
if global_step % CONFIG.print_step == 0:
|
||||
log_dict = {
|
||||
"step_time": [step_time, 2],
|
||||
"loader_time": [loader_time, 4],
|
||||
"current_lr": cur_lr,
|
||||
}
|
||||
c_logger.print_train_step(
|
||||
batch_n_iter,
|
||||
num_iter,
|
||||
global_step,
|
||||
log_dict,
|
||||
loss_dict,
|
||||
keep_avg.avg_values,
|
||||
)
|
||||
log_dict = {"step_time": [step_time, 2],
|
||||
"loader_time": [loader_time, 4],
|
||||
"current_lr": cur_lr,
|
||||
}
|
||||
c_logger.print_train_step(batch_n_iter,
|
||||
num_iter,
|
||||
global_step,
|
||||
log_dict,
|
||||
loss_dict,
|
||||
keep_avg.avg_values,
|
||||
)
|
||||
|
||||
# plot step stats
|
||||
if global_step % 10 == 0:
|
||||
|
@ -169,40 +167,38 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
|||
if global_step % CONFIG.save_step == 0:
|
||||
if CONFIG.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(
|
||||
model,
|
||||
optimizer,
|
||||
scheduler,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
model_losses=loss_dict,
|
||||
)
|
||||
save_checkpoint(model,
|
||||
optimizer,
|
||||
scheduler,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
model_losses=loss_dict,
|
||||
)
|
||||
|
||||
# synthesize a full voice
|
||||
wav_path = train_data[random.randrange(0, len(train_data))][0]
|
||||
wav = ap.load_wav(wav_path)
|
||||
ground_mel = ap.melspectrogram(wav)
|
||||
sample_wav = model.generate(
|
||||
ground_mel,
|
||||
CONFIG.batched,
|
||||
CONFIG.target_samples,
|
||||
CONFIG.overlap_samples,
|
||||
)
|
||||
sample_wav = model.generate(ground_mel,
|
||||
CONFIG.batched,
|
||||
CONFIG.target_samples,
|
||||
CONFIG.overlap_samples,
|
||||
)
|
||||
predict_mel = ap.melspectrogram(sample_wav)
|
||||
|
||||
# compute spectrograms
|
||||
figures = {
|
||||
"train/ground_truth": plot_spectrogram(ground_mel.T),
|
||||
"train/prediction": plot_spectrogram(predict_mel.T),
|
||||
}
|
||||
figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
|
||||
"train/prediction": plot_spectrogram(predict_mel.T),
|
||||
}
|
||||
|
||||
# Sample audio
|
||||
tb_logger.tb_train_audios(
|
||||
global_step, {"train/audio": sample_wav}, CONFIG.audio["sample_rate"]
|
||||
global_step, {
|
||||
"train/audio": sample_wav}, CONFIG.audio["sample_rate"]
|
||||
)
|
||||
|
||||
tb_logger.tb_train_figures(global_step, figures)
|
||||
|
@ -234,17 +230,17 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
# format data
|
||||
x, m, y = format_data(data)
|
||||
x_input, mels, y_coarse = format_data(data)
|
||||
loader_time = time.time() - end_time
|
||||
global_step += 1
|
||||
|
||||
y_hat = model(x, m)
|
||||
y_hat = model(x_input, mels)
|
||||
if isinstance(model.mode, int):
|
||||
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
|
||||
else:
|
||||
y = y.float()
|
||||
y = y.unsqueeze(-1)
|
||||
loss = criterion(y_hat, y)
|
||||
y_coarse = y_coarse.float()
|
||||
y_coarse = y_coarse.unsqueeze(-1)
|
||||
loss = criterion(y_hat, y_coarse)
|
||||
# Compute avg loss
|
||||
# if num_gpus > 1:
|
||||
# loss = reduce_tensor(loss.data, num_gpus)
|
||||
|
@ -264,30 +260,31 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
|
||||
# print eval stats
|
||||
if CONFIG.print_eval:
|
||||
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
||||
c_logger.print_eval_step(
|
||||
num_iter, loss_dict, keep_avg.avg_values)
|
||||
|
||||
if epoch % CONFIG.test_every_epochs == 0:
|
||||
if epoch % CONFIG.test_every_epochs == 0 and epoch != 0:
|
||||
# synthesize a part of data
|
||||
wav_path = eval_data[random.randrange(0, len(eval_data))][0]
|
||||
wav = ap.load_wav(wav_path)
|
||||
ground_mel = ap.melspectrogram(wav[:22000])
|
||||
sample_wav = model.generate(
|
||||
ground_mel,
|
||||
CONFIG.batched,
|
||||
CONFIG.target_samples,
|
||||
CONFIG.overlap_samples,
|
||||
)
|
||||
sample_wav = model.generate(ground_mel,
|
||||
CONFIG.batched,
|
||||
CONFIG.target_samples,
|
||||
CONFIG.overlap_samples,
|
||||
use_cuda
|
||||
)
|
||||
predict_mel = ap.melspectrogram(sample_wav)
|
||||
|
||||
# compute spectrograms
|
||||
figures = {
|
||||
"eval/ground_truth": plot_spectrogram(ground_mel.T),
|
||||
"eval/prediction": plot_spectrogram(predict_mel.T),
|
||||
}
|
||||
figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
|
||||
"eval/prediction": plot_spectrogram(predict_mel.T),
|
||||
}
|
||||
|
||||
# Sample audio
|
||||
tb_logger.tb_eval_audios(
|
||||
global_step, {"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
|
||||
global_step, {
|
||||
"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
|
||||
)
|
||||
|
||||
tb_logger.tb_eval_figures(global_step, figures)
|
||||
|
@ -372,7 +369,8 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
model_dict = set_init_dict(model_dict, checkpoint["model"], CONFIG)
|
||||
model_wavernn.load_state_dict(model_dict)
|
||||
|
||||
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
|
||||
print(" > Model restored from step %d" %
|
||||
checkpoint["step"], flush=True)
|
||||
args.restore_step = checkpoint["step"]
|
||||
else:
|
||||
args.restore_step = 0
|
||||
|
@ -393,7 +391,8 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
_, global_step = train(
|
||||
model_wavernn, optimizer, criterion, scheduler, ap, global_step, epoch
|
||||
)
|
||||
eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch)
|
||||
eval_avg_loss_dict = evaluate(
|
||||
model_wavernn, criterion, ap, global_step, epoch)
|
||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||
target_loss = eval_avg_loss_dict["avg_model_loss"]
|
||||
best_loss = save_best_model(
|
||||
|
@ -493,7 +492,8 @@ if __name__ == "__main__":
|
|||
tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
|
||||
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text("model-description", CONFIG["run_description"], 0)
|
||||
tb_logger.tb_add_text("model-description",
|
||||
CONFIG["run_description"], 0)
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
@ -8,17 +8,16 @@ class WaveRNNDataset(Dataset):
|
|||
WaveRNN Dataset searchs for all the wav files under root path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ap,
|
||||
items,
|
||||
seq_len,
|
||||
hop_len,
|
||||
pad,
|
||||
mode,
|
||||
is_training=True,
|
||||
verbose=False,
|
||||
):
|
||||
def __init__(self,
|
||||
ap,
|
||||
items,
|
||||
seq_len,
|
||||
hop_len,
|
||||
pad,
|
||||
mode,
|
||||
is_training=True,
|
||||
verbose=False,
|
||||
):
|
||||
|
||||
self.ap = ap
|
||||
self.item_list = items
|
||||
|
@ -56,17 +55,19 @@ class WaveRNNDataset(Dataset):
|
|||
|
||||
def collate(self, batch):
|
||||
mel_win = self.seq_len // self.hop_len + 2 * self.pad
|
||||
max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch]
|
||||
max_offsets = [x[0].shape[-1] -
|
||||
(mel_win + 2 * self.pad) for x in batch]
|
||||
mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
|
||||
sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets]
|
||||
sig_offsets = [(offset + self.pad) *
|
||||
self.hop_len for offset in mel_offsets]
|
||||
|
||||
mels = [
|
||||
x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win]
|
||||
x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win]
|
||||
for i, x in enumerate(batch)
|
||||
]
|
||||
|
||||
coarse = [
|
||||
x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1]
|
||||
x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1]
|
||||
for i, x in enumerate(batch)
|
||||
]
|
||||
|
||||
|
@ -79,7 +80,8 @@ class WaveRNNDataset(Dataset):
|
|||
coarse = np.stack(coarse).astype(np.int64)
|
||||
coarse = torch.LongTensor(coarse)
|
||||
x_input = (
|
||||
2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0
|
||||
2 * coarse[:, : self.seq_len].float() /
|
||||
(2 ** self.mode - 1.0) - 1.0
|
||||
)
|
||||
y_coarse = coarse[:, 1:]
|
||||
mels = torch.FloatTensor(mels)
|
||||
|
|
|
@ -39,7 +39,8 @@ class MelResNet(nn.Module):
|
|||
def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
|
||||
super().__init__()
|
||||
k_size = pad * 2 + 1
|
||||
self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
|
||||
self.conv_in = nn.Conv1d(
|
||||
in_dims, compute_dims, kernel_size=k_size, bias=False)
|
||||
self.batch_norm = nn.BatchNorm1d(compute_dims)
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(res_blocks):
|
||||
|
@ -94,7 +95,8 @@ class UpsampleNetwork(nn.Module):
|
|||
k_size = (1, scale * 2 + 1)
|
||||
padding = (0, scale)
|
||||
stretch = Stretch2d(scale, 1)
|
||||
conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
|
||||
conv = nn.Conv2d(1, 1, kernel_size=k_size,
|
||||
padding=padding, bias=False)
|
||||
conv.weight.data.fill_(1.0 / k_size[1])
|
||||
self.up_layers.append(stretch)
|
||||
self.up_layers.append(conv)
|
||||
|
@ -110,7 +112,7 @@ class UpsampleNetwork(nn.Module):
|
|||
m = m.unsqueeze(1)
|
||||
for f in self.up_layers:
|
||||
m = f(m)
|
||||
m = m.squeeze(1)[:, :, self.indent : -self.indent]
|
||||
m = m.squeeze(1)[:, :, self.indent: -self.indent]
|
||||
return m.transpose(1, 2), aux
|
||||
|
||||
|
||||
|
@ -123,7 +125,8 @@ class Upsample(nn.Module):
|
|||
self.pad = pad
|
||||
self.indent = pad * scale
|
||||
self.use_aux_net = use_aux_net
|
||||
self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
|
||||
self.resnet = MelResNet(res_blocks, feat_dims,
|
||||
compute_dims, res_out_dims, pad)
|
||||
|
||||
def forward(self, m):
|
||||
if self.use_aux_net:
|
||||
|
@ -137,7 +140,7 @@ class Upsample(nn.Module):
|
|||
m = torch.nn.functional.interpolate(
|
||||
m, scale_factor=self.scale, mode="linear", align_corners=True
|
||||
)
|
||||
m = m[:, :, self.indent : -self.indent]
|
||||
m = m[:, :, self.indent: -self.indent]
|
||||
m = m * 0.045 # empirically found
|
||||
|
||||
return m.transpose(1, 2), aux
|
||||
|
@ -207,7 +210,8 @@ class WaveRNN(nn.Module):
|
|||
if self.use_aux_net:
|
||||
self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
|
||||
self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
|
||||
self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
|
||||
self.rnn2 = nn.GRU(rnn_dims + self.aux_dims,
|
||||
rnn_dims, batch_first=True)
|
||||
self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
|
||||
self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
|
||||
self.fc3 = nn.Linear(fc_dims, self.n_classes)
|
||||
|
@ -221,16 +225,16 @@ class WaveRNN(nn.Module):
|
|||
|
||||
def forward(self, x, mels):
|
||||
bsize = x.size(0)
|
||||
h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
|
||||
h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
|
||||
h1 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
|
||||
h2 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
|
||||
mels, aux = self.upsample(mels)
|
||||
|
||||
if self.use_aux_net:
|
||||
aux_idx = [self.aux_dims * i for i in range(5)]
|
||||
a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
|
||||
a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
|
||||
a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
|
||||
a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
|
||||
a1 = aux[:, :, aux_idx[0]: aux_idx[1]]
|
||||
a2 = aux[:, :, aux_idx[1]: aux_idx[2]]
|
||||
a3 = aux[:, :, aux_idx[2]: aux_idx[3]]
|
||||
a4 = aux[:, :, aux_idx[3]: aux_idx[4]]
|
||||
|
||||
x = (
|
||||
torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
|
||||
|
@ -256,19 +260,21 @@ class WaveRNN(nn.Module):
|
|||
x = F.relu(self.fc2(x))
|
||||
return self.fc3(x)
|
||||
|
||||
def generate(self, mels, batched, target, overlap):
|
||||
def generate(self, mels, batched, target, overlap, use_cuda):
|
||||
|
||||
self.eval()
|
||||
device = 'cuda' if use_cuda else 'cpu'
|
||||
output = []
|
||||
start = time.time()
|
||||
rnn1 = self.get_gru_cell(self.rnn1)
|
||||
rnn2 = self.get_gru_cell(self.rnn2)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
|
||||
mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
|
||||
#mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
|
||||
wave_len = (mels.size(-1) - 1) * self.hop_length
|
||||
mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both")
|
||||
mels = self.pad_tensor(mels.transpose(
|
||||
1, 2), pad=self.pad, side="both")
|
||||
mels, aux = self.upsample(mels.transpose(1, 2))
|
||||
|
||||
if batched:
|
||||
|
@ -278,13 +284,13 @@ class WaveRNN(nn.Module):
|
|||
|
||||
b_size, seq_len, _ = mels.size()
|
||||
|
||||
h1 = torch.zeros(b_size, self.rnn_dims).cuda()
|
||||
h2 = torch.zeros(b_size, self.rnn_dims).cuda()
|
||||
x = torch.zeros(b_size, 1).cuda()
|
||||
h1 = torch.zeros(b_size, self.rnn_dims).to(device)
|
||||
h2 = torch.zeros(b_size, self.rnn_dims).to(device)
|
||||
x = torch.zeros(b_size, 1).to(device)
|
||||
|
||||
if self.use_aux_net:
|
||||
d = self.aux_dims
|
||||
aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
|
||||
aux_split = [aux[:, :, d * i: d * (i + 1)] for i in range(4)]
|
||||
|
||||
for i in range(seq_len):
|
||||
|
||||
|
@ -319,11 +325,12 @@ class WaveRNN(nn.Module):
|
|||
logits.unsqueeze(0).transpose(1, 2)
|
||||
)
|
||||
output.append(sample.view(-1))
|
||||
x = sample.transpose(0, 1).cuda()
|
||||
x = sample.transpose(0, 1).to(device)
|
||||
elif self.mode == "gauss":
|
||||
sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
|
||||
sample = sample_from_gaussian(
|
||||
logits.unsqueeze(0).transpose(1, 2))
|
||||
output.append(sample.view(-1))
|
||||
x = sample.transpose(0, 1).cuda()
|
||||
x = sample.transpose(0, 1).to(device)
|
||||
elif isinstance(self.mode, int):
|
||||
posterior = F.softmax(logits, dim=1)
|
||||
distrib = torch.distributions.Categorical(posterior)
|
||||
|
@ -332,7 +339,8 @@ class WaveRNN(nn.Module):
|
|||
output.append(sample)
|
||||
x = sample.unsqueeze(-1)
|
||||
else:
|
||||
raise RuntimeError("Unknown model mode value - ", self.mode)
|
||||
raise RuntimeError(
|
||||
"Unknown model mode value - ", self.mode)
|
||||
|
||||
if i % 100 == 0:
|
||||
self.gen_display(i, seq_len, b_size, start)
|
||||
|
@ -352,7 +360,7 @@ class WaveRNN(nn.Module):
|
|||
# Fade-out at the end to avoid signal cutting out suddenly
|
||||
fade_out = np.linspace(1, 0, 20 * self.hop_length)
|
||||
output = output[:wave_len]
|
||||
output[-20 * self.hop_length :] *= fade_out
|
||||
output[-20 * self.hop_length:] *= fade_out
|
||||
|
||||
self.train()
|
||||
return output
|
||||
|
@ -366,7 +374,6 @@ class WaveRNN(nn.Module):
|
|||
)
|
||||
|
||||
def fold_with_overlap(self, x, target, overlap):
|
||||
|
||||
"""Fold the tensor with overlap for quick batched inference.
|
||||
Overlap will be used for crossfading in xfade_and_unfold()
|
||||
Args:
|
||||
|
@ -398,7 +405,7 @@ class WaveRNN(nn.Module):
|
|||
padding = target + 2 * overlap - remaining
|
||||
x = self.pad_tensor(x, padding, side="after")
|
||||
|
||||
folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda()
|
||||
folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)
|
||||
|
||||
# Get the values for the folded tensor
|
||||
for i in range(num_folds):
|
||||
|
@ -423,16 +430,15 @@ class WaveRNN(nn.Module):
|
|||
# i.e., it won't generalise to other shapes/dims
|
||||
b, t, c = x.size()
|
||||
total = t + 2 * pad if side == "both" else t + pad
|
||||
padded = torch.zeros(b, total, c).cuda()
|
||||
padded = torch.zeros(b, total, c).to(x.device)
|
||||
if side in ("before", "both"):
|
||||
padded[:, pad : pad + t, :] = x
|
||||
padded[:, pad: pad + t, :] = x
|
||||
elif side == "after":
|
||||
padded[:, :t, :] = x
|
||||
return padded
|
||||
|
||||
@staticmethod
|
||||
def xfade_and_unfold(y, target, overlap):
|
||||
|
||||
"""Applies a crossfade and unfolds into a 1d array.
|
||||
Args:
|
||||
y (ndarry) : Batched sequences of audio samples
|
||||
|
|
Loading…
Reference in New Issue