compute audio feat on dataload

pull/10/head
sanjaesc 2020-10-25 09:45:37 +01:00
parent 4d5da4b663
commit 4a989e3ceb
4 changed files with 243 additions and 203 deletions


@@ -29,8 +29,8 @@ from TTS.utils.generic_utils import (
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.datasets.preprocess import (
find_feat_files,
load_wav_feat_data,
preprocess_wav_files,
load_wav_data,
load_wav_feat_data
)
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
from TTS.vocoder.utils.generic_utils import setup_wavernn
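The import change reflects the new data handling: the trainer now works from either a plain wav folder or a folder of precomputed features. A rough sketch of the two listing helpers as they are used below (return formats inferred from how the items are consumed, not copied from preprocess.py):

    # load_wav_data(data_path, eval_split_size)                 -> lists of bare wav paths
    # load_wav_feat_data(data_path, feat_path, eval_split_size) -> lists of [wav_path, feat_path] pairs
    eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
    compute_feat = not isinstance(train_data[0], (tuple, list))  # the same check WaveRNNDataset uses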
@@ -41,15 +41,16 @@ use_cuda, num_gpus = setup_torch_training_env(True, True)
def setup_loader(ap, is_val=False, verbose=False):
if is_val and not CONFIG.run_eval:
if is_val and not c.run_eval:
loader = None
else:
dataset = WaveRNNDataset(ap=ap,
items=eval_data if is_val else train_data,
seq_len=CONFIG.seq_len,
seq_len=c.seq_len,
hop_len=ap.hop_length,
pad=CONFIG.padding,
mode=CONFIG.mode,
pad=c.padding,
mode=c.mode,
mulaw=c.mulaw,
is_training=not is_val,
verbose=verbose,
)
@@ -57,10 +58,10 @@ def setup_loader(ap, is_val=False, verbose=False):
loader = DataLoader(dataset,
shuffle=True,
collate_fn=dataset.collate,
batch_size=CONFIG.batch_size,
num_workers=CONFIG.num_val_loader_workers
batch_size=c.batch_size,
num_workers=c.num_val_loader_workers
if is_val
else CONFIG.num_loader_workers,
else c.num_loader_workers,
pin_memory=True,
)
return loader
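For orientation, iterating the loader yields the triplet built by WaveRNNDataset.collate (shown further down). With the sample config (seq_len 1280, hop_length 256, padding 2, num_mels 80) the shapes come out roughly as sketched here; treat them as an assumption, not a spec:

    for x_input, mels, y_coarse in loader:
        # mels:     (batch, 80, 1280 // 256 + 2 * 2) = (batch, 80, 9) conditioning frames
        # x_input:  (batch, 1280) input samples, rescaled to [-1, 1] in the integer-bits mode
        # y_coarse: (batch, 1280) target samples, shifted one step ahead of x_input
        break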
@@ -89,9 +90,9 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(len(data_loader.dataset) /
(CONFIG.batch_size * num_gpus))
(c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / CONFIG.batch_size)
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
c_logger.print_train_start()
# train loop
@@ -102,9 +103,6 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
loader_time = time.time() - end_time
global_step += 1
##################
# MODEL TRAINING #
##################
y_hat = model(x_input, mels)
if isinstance(model.mode, int):
@@ -112,7 +110,6 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
else:
y_coarse = y_coarse.float()
y_coarse = y_coarse.unsqueeze(-1)
# m_scaled, _ = model.upsample(m)
# compute losses
loss = criterion(y_hat, y_coarse)
@@ -120,11 +117,11 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
raise RuntimeError(" [!] None loss. Exiting ...")
optimizer.zero_grad()
loss.backward()
if CONFIG.grad_clip > 0:
if c.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), CONFIG.grad_clip)
model.parameters(), c.grad_clip)
optimizer.step()
if scheduler is not None:
scheduler.step()
@@ -144,7 +141,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
keep_avg.update_values(update_train_values)
# print training stats
if global_step % CONFIG.print_step == 0:
if global_step % c.print_step == 0:
log_dict = {"step_time": [step_time, 2],
"loader_time": [loader_time, 4],
"current_lr": cur_lr,
@@ -164,8 +161,8 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
tb_logger.tb_train_iter_stats(global_step, iter_stats)
# save checkpoint
if global_step % CONFIG.save_step == 0:
if CONFIG.checkpoint:
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model,
optimizer,
@@ -180,28 +177,30 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
)
# synthesize a full voice
wav_path = train_data[random.randrange(0, len(train_data))][0]
rand_idx = random.randrange(0, len(train_data))
wav_path = train_data[rand_idx] if not isinstance(
train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
wav = ap.load_wav(wav_path)
ground_mel = ap.melspectrogram(wav)
sample_wav = model.generate(ground_mel,
CONFIG.batched,
CONFIG.target_samples,
CONFIG.overlap_samples,
c.batched,
c.target_samples,
c.overlap_samples,
use_cuda
)
predict_mel = ap.melspectrogram(sample_wav)
# compute spectrograms
figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
"train/prediction": plot_spectrogram(predict_mel.T),
"train/prediction": plot_spectrogram(predict_mel.T)
}
tb_logger.tb_train_figures(global_step, figures)
# Sample audio
tb_logger.tb_train_audios(
global_step, {
"train/audio": sample_wav}, CONFIG.audio["sample_rate"]
"train/audio": sample_wav}, c.audio["sample_rate"]
)
tb_logger.tb_train_figures(global_step, figures)
end_time = time.time()
# print epoch stats
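Since train_data items can now be bare wav paths or [wav, feat] pairs, the random sample is unpacked accordingly before synthesis. The generate() call maps directly onto the config further down; a minimal sketch, assuming a trained model and the AudioProcessor ap from above:

    item = train_data[random.randrange(len(train_data))]
    wav_path = item[0] if isinstance(item, (tuple, list)) else item
    ground_mel = ap.melspectrogram(ap.load_wav(wav_path))      # (num_mels, frames)
    sample_wav = model.generate(ground_mel, batched=True, target=11000, overlap=550, use_cuda=False)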
@@ -259,34 +258,35 @@ def evaluate(model, criterion, ap, global_step, epoch):
keep_avg.update_values(update_eval_values)
# print eval stats
if CONFIG.print_eval:
if c.print_eval:
c_logger.print_eval_step(
num_iter, loss_dict, keep_avg.avg_values)
if epoch % CONFIG.test_every_epochs == 0 and epoch != 0:
# synthesize a part of data
wav_path = eval_data[random.randrange(0, len(eval_data))][0]
if epoch % c.test_every_epochs == 0 and epoch != 0:
# synthesize a full voice
rand_idx = random.randrange(0, len(eval_data))
wav_path = eval_data[rand_idx] if not isinstance(
eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
wav = ap.load_wav(wav_path)
ground_mel = ap.melspectrogram(wav[:22000])
ground_mel = ap.melspectrogram(wav)
sample_wav = model.generate(ground_mel,
CONFIG.batched,
CONFIG.target_samples,
CONFIG.overlap_samples,
c.batched,
c.target_samples,
c.overlap_samples,
use_cuda
)
predict_mel = ap.melspectrogram(sample_wav)
# compute spectrograms
figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
"eval/prediction": plot_spectrogram(predict_mel.T),
}
# Sample audio
tb_logger.tb_eval_audios(
global_step, {
"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
"eval/audio": sample_wav}, c.audio["sample_rate"]
)
# compute spectrograms
figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
"eval/prediction": plot_spectrogram(predict_mel.T)
}
tb_logger.tb_eval_figures(global_step, figures)
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
@@ -299,53 +299,62 @@ def main(args): # pylint: disable=redefined-outer-name
global train_data, eval_data
# setup audio processor
ap = AudioProcessor(**CONFIG.audio)
ap = AudioProcessor(**c.audio)
print(f" > Loading wavs from: {CONFIG.data_path}")
if CONFIG.feature_path is not None:
print(f" > Loading features from: {CONFIG.feature_path}")
# print(f" > Loading wavs from: {c.data_path}")
# if c.feature_path is not None:
# print(f" > Loading features from: {c.feature_path}")
# eval_data, train_data = load_wav_feat_data(
# c.data_path, c.feature_path, c.eval_split_size
# )
# else:
# mel_feat_path = os.path.join(OUT_PATH, "mel")
# feat_data = find_feat_files(mel_feat_path)
# if feat_data:
# print(f" > Loading features from: {mel_feat_path}")
# eval_data, train_data = load_wav_feat_data(
# c.data_path, mel_feat_path, c.eval_split_size
# )
# else:
# print(" > No feature data found. Preprocessing...")
# # preprocessing feature data from given wav files
# preprocess_wav_files(OUT_PATH, CONFIG, ap)
# eval_data, train_data = load_wav_feat_data(
# c.data_path, mel_feat_path, c.eval_split_size
# )
print(f" > Loading wavs from: {c.data_path}")
if c.feature_path is not None:
print(f" > Loading features from: {c.feature_path}")
eval_data, train_data = load_wav_feat_data(
CONFIG.data_path, CONFIG.feature_path, CONFIG.eval_split_size
)
c.data_path, c.feature_path, c.eval_split_size)
else:
mel_feat_path = os.path.join(OUT_PATH, "mel")
feat_data = find_feat_files(mel_feat_path)
if feat_data:
print(f" > Loading features from: {mel_feat_path}")
eval_data, train_data = load_wav_feat_data(
CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
)
else:
print(" > No feature data found. Preprocessing...")
# preprocessing feature data from given wav files
preprocess_wav_files(OUT_PATH, CONFIG, ap)
eval_data, train_data = load_wav_feat_data(
CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
)
eval_data, train_data = load_wav_data(
c.data_path, c.eval_split_size)
# setup model
model_wavernn = setup_wavernn(CONFIG)
model_wavernn = setup_wavernn(c)
# define train functions
if CONFIG.mode == "mold":
if c.mode == "mold":
criterion = discretized_mix_logistic_loss
elif CONFIG.mode == "gauss":
elif c.mode == "gauss":
criterion = gaussian_loss
elif isinstance(CONFIG.mode, int):
elif isinstance(c.mode, int):
criterion = torch.nn.CrossEntropyLoss()
if use_cuda:
model_wavernn.cuda()
if isinstance(CONFIG.mode, int):
if isinstance(c.mode, int):
criterion.cuda()
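The criterion tracks c.mode: distribution losses for "mold" and "gauss", plain cross-entropy over 2**bits classes for an integer mode. With the sample config ("mode": 10, "mulaw": true) the targets are 10-bit mu-law indices; a shape sketch for the cross-entropy case (layout assumed for illustration only):

    import torch
    bits = 10
    y_hat = torch.randn(2, 1280, 2 ** bits)              # (batch, time, classes) logits
    y_coarse = torch.randint(0, 2 ** bits, (2, 1280))    # (batch, time) class indices
    loss = torch.nn.CrossEntropyLoss()(y_hat.transpose(1, 2), y_coarse)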
optimizer = RAdam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)
scheduler = None
if "lr_scheduler" in CONFIG:
scheduler = getattr(torch.optim.lr_scheduler, CONFIG.lr_scheduler)
scheduler = scheduler(optimizer, **CONFIG.lr_scheduler_params)
if "lr_scheduler" in c:
scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
scheduler = scheduler(optimizer, **c.lr_scheduler_params)
# slow start for the first 5 epochs
# lr_lambda = lambda epoch: min(epoch / CONFIG.warmup_steps, 1)
# lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
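With the sample config further down, the getattr() call resolves to the scheduler sketched here; since scheduler.step() runs once per training step in the loop above, the milestones count global steps, not epochs:

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, gamma=0.5, milestones=[200000, 400000, 600000])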
# restore any checkpoint
@@ -366,7 +375,7 @@ def main(args): # pylint: disable=redefined-outer-name
# restore only matching layers.
print(" > Partial model initialization...")
model_dict = model_wavernn.state_dict()
model_dict = set_init_dict(model_dict, checkpoint["model"], CONFIG)
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model_wavernn.load_state_dict(model_dict)
print(" > Model restored from step %d" %
@@ -386,11 +395,10 @@ def main(args): # pylint: disable=redefined-outer-name
best_loss = float("inf")
global_step = args.restore_step
for epoch in range(0, CONFIG.epochs):
c_logger.print_epoch_start(epoch, CONFIG.epochs)
_, global_step = train(
model_wavernn, optimizer, criterion, scheduler, ap, global_step, epoch
)
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
_, global_step = train(model_wavernn, optimizer,
criterion, scheduler, ap, global_step, epoch)
eval_avg_loss_dict = evaluate(
model_wavernn, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
@@ -462,14 +470,14 @@ if __name__ == "__main__":
print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs
CONFIG = load_config(args.config_path)
c = load_config(args.config_path)
# check_config(c)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = args.continue_path
if args.continue_path == "":
OUT_PATH = create_experiment_folder(
CONFIG.output_path, CONFIG.run_name, args.debug
c.output_path, c.run_name, args.debug
)
AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")
@@ -483,7 +491,7 @@ if __name__ == "__main__":
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_config_file(
args.config_path, os.path.join(OUT_PATH, "config.json"), new_fields
args.config_path, os.path.join(OUT_PATH, "c.json"), new_fields
)
os.chmod(AUDIO_PATH, 0o775)
os.chmod(OUT_PATH, 0o775)
@@ -492,8 +500,7 @@ if __name__ == "__main__":
tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
# write model desc to tensorboard
tb_logger.tb_add_text("model-description",
CONFIG["run_description"], 0)
tb_logger.tb_add_text("model-description", c["run_description"], 0)
try:
main(args)


@@ -2,93 +2,96 @@
"run_name": "wavernn_test",
"run_description": "wavernn_test training",
// AUDIO PARAMETERS
"audio":{
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
"win_length": 1024, // stft window length in ms.
"hop_length": 256, // stft window hop-lengh in ms.
// AUDIO PARAMETERS
"audio": {
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
"win_length": 1024, // stft window length in ms.
"hop_length": 256, // stft window hop-lengh in ms.
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
// Audio processing parameters
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
// Silence trimming
"do_trim_silence": false,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
"do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
// MelSpectrogram parameters
"num_mels": 80, // size of the mel spec frame.
"mel_fmin": 40.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"spec_gain": 20.0, // scaler value appplied after log transform of spectrogram.
"num_mels": 80, // size of the mel spec frame.
"mel_fmin": 40.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"spec_gain": 20.0, // scaler value appplied after log transform of spectrogram.
// Normalization parameters
"signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
"min_level_db": -100, // lower bound for normalization
"signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
"min_level_db": -100, // lower bound for normalization
"symmetric_norm": true, // move normalization to range [-1, 1]
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
},
// Generating / Synthesizing
// Generating / Synthesizing
"batched": true,
"target_samples": 11000, // target number of samples to be generated in each batch entry
"overlap_samples": 550, // number of samples for crossfading between batches
"target_samples": 11000, // target number of samples to be generated in each batch entry
"overlap_samples": 550, // number of samples for crossfading between batches
// DISTRIBUTED TRAINING
// "distributed":{
// "backend": "nccl",
// "url": "tcp:\/\/localhost:54321"
// },
// MODEL PARAMETERS
"use_aux_net": true,
"use_upsample_net": true,
"upsample_factors": [4, 8, 8], // this needs to correctly factorise hop_length
"seq_len": 1280, // has to be devideable by hop_length
"mode": "mold", // mold [string], gauss [string], bits [int]
"mulaw": false, // apply mulaw if mode is bits
"padding": 2, // pad the input for resnet to see wider input length
// MODEL MODE
"mode": 10, // mold [string], gauss [string], bits [int]
"mulaw": true, // apply mulaw if mode is bits
// DATASET
//"use_gta": true, // use computed gta features from the tts model
"data_path": "path/to/wav/files", // path containing training wav files
"feature_path": null, // path containing computed features from wav files if null compute them
// MODEL PARAMETERS
"wavernn_model_params": {
"rnn_dims": 512,
"fc_dims": 512,
"compute_dims": 128,
"res_out_dims": 128,
"num_res_blocks": 10,
"use_aux_net": true,
"use_upsample_net": true,
"upsample_factors": [4, 8, 8] // this needs to correctly factorise hop_length
},
// TRAINING
"batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"epochs": 10000, // total number of epochs to train.
// DATASET
//"use_gta": true, // use computed gta features from the tts model
"data_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech", // path containing training wav files
"feature_path": null, // path containing computed features from wav files if null compute them
"seq_len": 1280, // has to be devideable by hop_length
"padding": 2, // pad the input for resnet to see wider input length
// VALIDATION
// TRAINING
"batch_size": 64, // Batch size for training.
"epochs": 10000, // total number of epochs to train.
// VALIDATION
"run_eval": true,
"test_every_epochs": 10, // Test after set number of epochs (Test every 20 epochs for example)
"test_every_epochs": 10, // Test after set number of epochs (Test every 10 epochs for example)
// OPTIMIZER
"grad_clip": 4, // apply gradient clipping if > 0
"lr_scheduler": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
// OPTIMIZER
"grad_clip": 4, // apply gradient clipping if > 0
"lr_scheduler": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
"lr_scheduler_params": {
"gamma": 0.5,
"milestones": [200000, 400000, 600000]
},
"lr": 1e-4, // initial learning rate
"lr": 1e-4, // initial learning rate
// TENSORBOARD and LOGGING
"print_step": 25, // Number of steps to log traning on console.
"print_eval": false, // If True, it prints loss values for each step in eval run.
"save_step": 25000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
// TENSORBOARD and LOGGING
"print_step": 25, // Number of steps to log traning on console.
"print_eval": false, // If True, it prints loss values for each step in eval run.
"save_step": 25000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
// DATA LOADING
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"num_val_loader_workers": 4, // number of evaluation data loader processes.
"eval_split_size": 50, // number of samples for testing
// DATA LOADING
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"num_val_loader_workers": 4, // number of evaluation data loader processes.
"eval_split_size": 50, // number of samples for testing
// PATHS
// PATHS
"output_path": "output/training/path"
}
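As the inline comments require, the upsample factors must factorise hop_length and seq_len must be divisible by it; a quick sanity check with the values above:

    assert 4 * 8 * 8 == 256    # upsample_factors must multiply to hop_length
    assert 1280 % 256 == 0     # seq_len covers exactly 5 hops (plus 2 * padding frames of mel context)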


@@ -1,11 +1,13 @@
import torch
import numpy as np
from torch.utils.data import Dataset
from multiprocessing import Manager
class WaveRNNDataset(Dataset):
"""
WaveRNN Dataset searches for all the wav files under root path.
WaveRNN Dataset searches for all the wav files under root path
and converts them to acoustic features on the fly.
"""
def __init__(self,
@@ -15,16 +17,19 @@ class WaveRNNDataset(Dataset):
hop_len,
pad,
mode,
mulaw,
is_training=True,
verbose=False,
):
self.ap = ap
self.compute_feat = not isinstance(items[0], (tuple, list))
self.item_list = items
self.seq_len = seq_len
self.hop_len = hop_len
self.pad = pad
self.mode = mode
self.mulaw = mulaw
self.is_training = is_training
self.verbose = verbose
@@ -36,22 +41,47 @@ class WaveRNNDataset(Dataset):
return item
def load_item(self, index):
wavpath, feat_path = self.item_list[index]
m = np.load(feat_path.replace("/quant/", "/mel/"))
# x = self.wav_cache[index]
if m.shape[-1] < 5:
print(" [!] Instance is too short! : {}".format(wavpath))
self.item_list[index] = self.item_list[index + 1]
feat_path = self.item_list[index]
m = np.load(feat_path.replace("/quant/", "/mel/"))
if self.mode in ["gauss", "mold"]:
# x = np.load(feat_path.replace("/mel/", "/quant/"))
x = self.ap.load_wav(wavpath)
elif isinstance(self.mode, int):
x = np.load(feat_path.replace("/mel/", "/quant/"))
"""
load (audio, feat) couple if feature_path is set
else compute it on the fly
"""
if self.compute_feat:
wavpath = self.item_list[index]
audio = self.ap.load_wav(wavpath)
mel = self.ap.melspectrogram(audio)
if mel.shape[-1] < 5:
print(" [!] Instance is too short! : {}".format(wavpath))
self.item_list[index] = self.item_list[index + 1]
audio = self.ap.load_wav(wavpath)
mel = self.ap.melspectrogram(audio)
if self.mode in ["gauss", "mold"]:
x_input = audio
elif isinstance(self.mode, int):
x_input = (self.ap.mulaw_encode(audio, qc=self.mode)
if self.mulaw else self.ap.quantize(audio, bits=self.mode))
else:
raise RuntimeError("Unknown dataset mode - ", self.mode)
else:
raise RuntimeError("Unknown dataset mode - ", self.mode)
return m, x
wavpath, feat_path = self.item_list[index]
mel = np.load(feat_path.replace("/quant/", "/mel/"))
if mel.shape[-1] < 5:
print(" [!] Instance is too short! : {}".format(wavpath))
self.item_list[index] = self.item_list[index + 1]
feat_path = self.item_list[index]
mel = np.load(feat_path.replace("/quant/", "/mel/"))
if self.mode in ["gauss", "mold"]:
x_input = self.ap.load_wav(wavpath)
elif isinstance(self.mode, int):
x_input = np.load(feat_path.replace("/mel/", "/quant/"))
else:
raise RuntimeError("Unknown dataset mode - ", self.mode)
return mel, x_input
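In the integer-bits branch the raw audio becomes the training target through mu-law companding or plain linear quantization. The standard formulas look roughly like this; the exact AudioProcessor.mulaw_encode / quantize implementations may differ in detail:

    import numpy as np

    def mulaw_encode(wav, qc):
        mu = 2 ** qc - 1
        signal = np.sign(wav) * np.log1p(mu * np.abs(wav)) / np.log1p(mu)
        return np.floor((signal + 1) / 2 * mu + 0.5).astype(np.int64)   # indices in [0, 2**qc - 1]

    def quantize(wav, bits):
        return ((wav + 1.0) * (2 ** bits - 1) / 2).astype(np.int64)     # linear map of [-1, 1] to indices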
def collate(self, batch):
mel_win = self.seq_len // self.hop_len + 2 * self.pad
@@ -79,10 +109,8 @@ class WaveRNNDataset(Dataset):
elif isinstance(self.mode, int):
coarse = np.stack(coarse).astype(np.int64)
coarse = torch.LongTensor(coarse)
x_input = (
2 * coarse[:, : self.seq_len].float() /
(2 ** self.mode - 1.0) - 1.0
)
x_input = (2 * coarse[:, : self.seq_len].float() /
(2 ** self.mode - 1.0) - 1.0)
y_coarse = coarse[:, 1:]
mels = torch.FloatTensor(mels)
return x_input, mels, y_coarse
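The rescaling in the integer branch maps class indices back onto the network's float input range; a quick check with the 10-bit mode from the sample config:

    import torch
    bits = 10
    coarse = torch.arange(2 ** bits).float()       # indices 0 .. 1023
    x = 2 * coarse / (2 ** bits - 1.0) - 1.0
    assert x[0].item() == -1.0 and x[-1].item() == 1.0   # inputs live in [-1, 1]; y_coarse keeps raw indices, shifted by one sample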


@@ -36,14 +36,14 @@ class ResBlock(nn.Module):
class MelResNet(nn.Module):
def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad):
super().__init__()
k_size = pad * 2 + 1
self.conv_in = nn.Conv1d(
in_dims, compute_dims, kernel_size=k_size, bias=False)
self.batch_norm = nn.BatchNorm1d(compute_dims)
self.layers = nn.ModuleList()
for _ in range(res_blocks):
for _ in range(num_res_blocks):
self.layers.append(ResBlock(compute_dims))
self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
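The kernel size k_size = pad * 2 + 1 is what ties the dataset's padding option to the network: with padding 2 from the sample config, conv_in is a valid (unpadded) convolution over 5 mel frames, which consumes exactly the 2 * pad extra context frames that collate packs around each seq_len window:

    pad = 2
    k_size = pad * 2 + 1              # 5 mel frames seen per output frame
    mel_win = 1280 // 256 + 2 * pad   # 9 frames loaded per item, 2 * pad of them pure context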
@@ -76,7 +76,7 @@ class UpsampleNetwork(nn.Module):
feat_dims,
upsample_scales,
compute_dims,
res_blocks,
num_res_blocks,
res_out_dims,
pad,
use_aux_net,
@@ -87,7 +87,7 @@ class UpsampleNetwork(nn.Module):
self.use_aux_net = use_aux_net
if use_aux_net:
self.resnet = MelResNet(
res_blocks, feat_dims, compute_dims, res_out_dims, pad
num_res_blocks, feat_dims, compute_dims, res_out_dims, pad
)
self.resnet_stretch = Stretch2d(self.total_scale, 1)
self.up_layers = nn.ModuleList()
@@ -118,14 +118,14 @@ class UpsampleNetwork(nn.Module):
class Upsample(nn.Module):
def __init__(
self, scale, pad, res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net
self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net
):
super().__init__()
self.scale = scale
self.pad = pad
self.indent = pad * scale
self.use_aux_net = use_aux_net
self.resnet = MelResNet(res_blocks, feat_dims,
self.resnet = MelResNet(num_res_blocks, feat_dims,
compute_dims, res_out_dims, pad)
def forward(self, m):
@@ -147,23 +147,22 @@ class Upsample(nn.Module):
class WaveRNN(nn.Module):
def __init__(
self,
rnn_dims,
fc_dims,
mode,
mulaw,
pad,
use_aux_net,
use_upsample_net,
upsample_factors,
feat_dims,
compute_dims,
res_out_dims,
res_blocks,
hop_length,
sample_rate,
):
def __init__(self,
rnn_dims,
fc_dims,
mode,
mulaw,
pad,
use_aux_net,
use_upsample_net,
upsample_factors,
feat_dims,
compute_dims,
res_out_dims,
num_res_blocks,
hop_length,
sample_rate,
):
super().__init__()
self.mode = mode
self.mulaw = mulaw
@@ -177,7 +176,7 @@ class WaveRNN(nn.Module):
elif self.mode == "gauss":
self.n_classes = 2
else:
raise RuntimeError(" > Unknown training mode")
raise RuntimeError("Unknown model mode value - ", self.mode)
self.rnn_dims = rnn_dims
self.aux_dims = res_out_dims // 4
@@ -192,7 +191,7 @@ class WaveRNN(nn.Module):
feat_dims,
upsample_factors,
compute_dims,
res_blocks,
num_res_blocks,
res_out_dims,
pad,
use_aux_net,
@@ -201,7 +200,7 @@ class WaveRNN(nn.Module):
self.upsample = Upsample(
hop_length,
pad,
res_blocks,
num_res_blocks,
feat_dims,
compute_dims,
res_out_dims,
@@ -260,7 +259,7 @@ class WaveRNN(nn.Module):
x = F.relu(self.fc2(x))
return self.fc3(x)
def generate(self, mels, batched, target, overlap, use_cuda):
def generate(self, mels, batched, target, overlap, use_cuda=False):
self.eval()
device = 'cuda' if use_cuda else 'cpu'
@@ -360,7 +359,9 @@ class WaveRNN(nn.Module):
# Fade-out at the end to avoid signal cutting out suddenly
fade_out = np.linspace(1, 0, 20 * self.hop_length)
output = output[:wave_len]
output[-20 * self.hop_length:] *= fade_out
if wave_len > len(fade_out):
output[-20 * self.hop_length:] *= fade_out
self.train()
return output
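The added length check matters for very short outputs: fade_out has 20 * hop_length samples (5120 at the default hop of 256), and applying it to a clip shorter than that would fail with a shape mismatch, which is presumably what the guard avoids:

    fade_out = np.linspace(1, 0, 20 * 256)    # 5120-sample linear fade
    if wave_len > len(fade_out):              # only fade clips long enough to hold it
        output[-len(fade_out):] *= fade_out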
@@ -405,7 +406,8 @@ class WaveRNN(nn.Module):
padding = target + 2 * overlap - remaining
x = self.pad_tensor(x, padding, side="after")
folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)
folded = torch.zeros(num_folds, target + 2 *
overlap, features).to(x.device)
# Get the values for the folded tensor
for i in range(num_folds):
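For context on the fold allocation above: batched generation splits the upsampled conditioning into num_folds segments of target + 2 * overlap steps and later crossfades the overlaps back together. A rough worked example with the sample config (target_samples 11000, overlap_samples 550), assuming the usual fold-with-overlap arithmetic:

    target, overlap = 11000, 550
    total_len = 66150                                        # e.g. a 3 s clip at 22050 Hz
    num_folds = (total_len - overlap) // (target + overlap)  # 5
    remaining = total_len - (num_folds * (target + overlap) + overlap)
    if remaining != 0:
        num_folds += 1      # 6 folds, the last one padded by target + 2 * overlap - remaining
    # each fold covers target + 2 * overlap = 12100 steps; the 550-step overlaps are crossfaded on unfold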