tune wavegrad to find the best noise schedule for inference

pull/10/head
erogol 2020-11-06 13:04:46 +01:00
parent d94782a076
commit c80225544e
4 changed files with 130 additions and 12 deletions

TTS/bin/tune_wavegrad.py

@@ -0,0 +1,89 @@
"""Search a good noise schedule for WaveGrad for a given number of inferece iterations"""
import argparse
from itertools import product as cartesian_product
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.models.wavegrad import Wavegrad
from TTS.vocoder.utils.generic_utils import setup_generator
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='Path to model checkpoint.')
parser.add_argument('--config_path', type=str, help='Path to model config file.')
parser.add_argument('--data_path', type=str, help='Path to data directory.')
parser.add_argument('--output_path', type=str, help='Path to the output file, including file name and extension.')
parser.add_argument('--num_iter', type=int, help='Number of model inference iterations to optimize the noise schedule for.')
parser.add_argument('--use_cuda', action='store_true', help='Enable CUDA.')
parser.add_argument('--num_samples', type=int, default=1, help='Number of data samples used for inference.')
parser.add_argument('--search_depth', type=int, default=3, help='Search granularity. Increasing this increases the run-time exponentially.')
# load config
args = parser.parse_args()
config = load_config(args.config_path)
# setup audio processor
ap = AudioProcessor(**config.audio)
# load dataset
_, train_data = load_wav_data(args.data_path, 0)
train_data = train_data[:args.num_samples]
dataset = WaveGradDataset(ap=ap,
items=train_data,
seq_len=ap.hop_length * 100,
hop_len=ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
is_training=True,
return_segments=False,
use_noise_augment=False,
use_cache=False,
verbose=True)
loader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
collate_fn=dataset.collate_full_clips,
drop_last=False,
num_workers=config.num_loader_workers,
pin_memory=False)
# setup the model and load the checkpoint weights
model = setup_generator(config)
model.load_state_dict(torch.load(args.model_path, map_location='cpu')['model'])  # TTS checkpoints keep weights under the 'model' key
model.eval()
if args.use_cuda:
    model.cuda()
# setup optimization parameters
base_values = sorted(np.random.uniform(high=10, size=args.search_depth))
best_error = float('inf')
best_schedule = None
total_search_iter = len(base_values)**args.num_iter
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
model.compute_noise_level(num_steps=args.num_iter, min_val=1e-6, max_val=1e-1, base_vals=base)
for data in loader:
mel, audio = data
y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
if args.use_cuda:
y_hat = y_hat.cpu()
y_hat = y_hat.numpy()
mel_hat = []
for i in range(y_hat.shape[0]):
m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
mel_hat.append(torch.from_numpy(m))
mel_hat = torch.stack(mel_hat)
        mse = torch.sum((mel - mel_hat) ** 2)  # summed (not mean) squared error over the mel spectrogram
if mse.item() < best_error:
best_error = mse.item()
best_schedule = {'num_steps': args.num_iter, 'min_val':1e-6, 'max_val':1e-1, 'base_vals':base}
print(" > Found a better schedule.")
np.save(args.output_path, best_schedule)
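Once the search finishes, the saved schedule can be swapped in at synthesis time. A minimal usage sketch; all file names below are hypothetical, and the checkpoint's 'model' key follows the usual TTS checkpoint layout:

import torch

from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.vocoder.utils.generic_utils import setup_generator

config = load_config('config.json')  # hypothetical paths throughout
ap = AudioProcessor(**config.audio)
model = setup_generator(config)
model.load_state_dict(torch.load('checkpoint.pth.tar', map_location='cpu')['model'])
model.eval()

# replace the long training schedule with the tuned short one
model.load_noise_schedule('noise_schedule.npy')  # file written by this script

wav = ap.load_wav('test.wav')
mel = torch.from_numpy(ap.melspectrogram(wav)).unsqueeze(0).float()  # B x D x T
y_hat = model.inference(mel)  # now runs only the tuned num_iter diffusion steps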

TTS/utils/audio.py

@@ -174,7 +174,7 @@ class AudioProcessor(object):
         for key in stats_config.keys():
             if key in skip_parameters:
                 continue
-            if key != 'sample_rate':
+            if key not in ['sample_rate', 'trim_db']:
                 assert stats_config[key] == self.__dict__[key],\
                     f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
         return mel_mean, mel_std, linear_mean, linear_std, stats_config

TTS/vocoder/datasets/wavegrad_dataset.py

@@ -81,11 +81,12 @@ class WaveGradDataset(Dataset):
         else:
             audio = self.ap.load_wav(wavpath)
-        # correct audio length wrt segment length
-        if audio.shape[-1] < self.seq_len + self.pad_short:
-            audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
-                mode='constant', constant_values=0.0)
-        assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
         if self.return_segments:
+            # correct audio length wrt segment length
+            if audio.shape[-1] < self.seq_len + self.pad_short:
+                audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
+                    mode='constant', constant_values=0.0)
+            assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
             # correct the audio length wrt hop length
             p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
@@ -104,8 +105,26 @@
             audio = audio + (1 / 32768) * torch.randn_like(audio)
         mel = self.ap.melspectrogram(audio)
-        mel = mel[..., :-1]
+        mel = mel[..., :-1]  # ignore the padding
         audio = torch.from_numpy(audio).float()
         mel = torch.from_numpy(mel).float().squeeze(0)
         return (mel, audio)
+
+    def collate_full_clips(self, batch):
+        """This is used in tune_wavegrad.py.
+        It pads sequences to the max length."""
+        max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1]
+        max_audio_length = max([b[1].shape[0] for b in batch]) if len(batch) > 1 else batch[0][1].shape[0]
+
+        mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length])
+        audios = torch.zeros([len(batch), max_audio_length])
+
+        for idx, b in enumerate(batch):
+            mel = b[0]
+            audio = b[1]
+            mels[idx, :, :mel.shape[1]] = mel
+            audios[idx, :audio.shape[0]] = audio
+        return mels, audios
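A quick sanity check of the new collate function on dummy tensors (the shapes here, including 80 mel bins, are hypothetical). Since collate_full_clips never touches self, it can be called unbound for a standalone test:

import torch
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset

# two (mel, audio) pairs of unequal length
batch = [(torch.ones(80, 50), torch.ones(12800)),
         (torch.ones(80, 30), torch.ones(7680))]
mels, audios = WaveGradDataset.collate_full_clips(None, batch)  # self is unused
print(mels.shape)    # torch.Size([2, 80, 50])  - zero-padded to the longest mel
print(audios.shape)  # torch.Size([2, 12800])   - zero-padded to the longest audio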

TTS/vocoder/models/wavegrad.py

@@ -78,13 +78,21 @@ class Wavegrad(nn.Module):
         x = self.out_conv(x)
         return x
 
+    def load_noise_schedule(self, path):
+        sched = np.load(path, allow_pickle=True).item()
+        self.compute_noise_level(**sched)
+
     @torch.no_grad()
-    def inference(self, x):
-        y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
-        sqrt_alpha_hat = self.noise_level.unsqueeze(1).to(x)
+    def inference(self, x, y_n=None):
+        """x: B x D x T"""
+        if y_n is None:
+            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
+        else:
+            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
+        sqrt_alpha_hat = self.noise_level.to(x)
         for n in range(len(self.alpha) - 1, -1, -1):
             y_n = self.c1[n] * (y_n -
-                                self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n]).squeeze(1))
+                                self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
             if n > 0:
                 z = torch.randn_like(y_n)
                 y_n += self.sigma[n - 1] * z
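For reference, the loop above is the standard WaveGrad/DDPM reverse step. Assuming the usual correspondence c1[n] = 1/\sqrt{\alpha_n} and c2[n] = (1 - \alpha_n)/\sqrt{1 - \bar{\alpha}_n} (those buffers are filled elsewhere and not shown in this diff), each iteration computes

    y_{n-1} = \frac{1}{\sqrt{\alpha_n}} \left( y_n - \frac{1 - \alpha_n}{\sqrt{1 - \bar{\alpha}_n}} \, \epsilon_\theta(y_n, x, \sqrt{\bar{\alpha}_n}) \right) + \sigma_{n-1} z, \qquad z \sim \mathcal{N}(0, I)

with the noise term skipped at the final step (n = 0).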
@@ -105,9 +113,11 @@
         noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise
         return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
 
-    def compute_noise_level(self, num_steps, min_val, max_val):
+    def compute_noise_level(self, num_steps, min_val, max_val, base_vals=None):
         """Compute noise schedule parameters"""
         beta = np.linspace(min_val, max_val, num_steps)
+        if base_vals is not None:
+            beta *= base_vals
         alpha = 1 - beta
         alpha_hat = np.cumprod(alpha)
         noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
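Taken together, the tuner searches per-step multipliers for this linear beta schedule. A minimal standalone sketch of the computation with hypothetical values (the tuner samples its base values from U(0, 10)):

import numpy as np

# hypothetical 6-step schedule; one base multiplier per inference step
num_steps, min_val, max_val = 6, 1e-6, 1e-1
base_vals = np.array([0.3, 0.7, 1.0, 2.5, 4.0, 8.0])

beta = np.linspace(min_val, max_val, num_steps) * base_vals  # scaled linear schedule
alpha = 1 - beta
alpha_hat = np.cumprod(alpha)
noise_level = np.concatenate([[1.0], alpha_hat ** 0.5])  # note the prepended 1.0

print(noise_level.round(4))  # monotonically decreasing from 1.0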