mirror of https://github.com/coqui-ai/TTS.git
tune wavegrad to fine the best noise schedule for inferece
parent
d94782a076
commit
c80225544e
|
@ -0,0 +1,89 @@
|
|||
"""Search a good noise schedule for WaveGrad for a given number of inferece iterations"""
|
||||
import argparse
|
||||
from itertools import product as cartesian_product
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
||||
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
||||
from TTS.vocoder.models.wavegrad import Wavegrad
|
||||
from TTS.vocoder.utils.generic_utils import setup_generator
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_path', type=str, help='Path to model checkpoint.')
|
||||
parser.add_argument('--config_path', type=str, help='Path to model config file.')
|
||||
parser.add_argument('--data_path', type=str, help='Path to data directory.')
|
||||
parser.add_argument('--output_path', type=str, help='path for output file including file name and extension.')
|
||||
parser.add_argument('--num_iter', type=int, help='Number of model inference iterations that you like to optimize noise schedule for.')
|
||||
parser.add_argument('--use_cuda', type=bool, help='enable/disable CUDA.')
|
||||
parser.add_argument('--num_samples', type=int, default=1, help='Number of datasamples used for inference.')
|
||||
parser.add_argument('--search_depth', type=int, default=3, help='Search granularity. Increasing this increases the run-time exponentially.')
|
||||
|
||||
# load config
|
||||
args = parser.parse_args()
|
||||
config = load_config(args.config_path)
|
||||
|
||||
# setup audio processor
|
||||
ap = AudioProcessor(**config.audio)
|
||||
|
||||
# load dataset
|
||||
_, train_data = load_wav_data(args.data_path, 0)
|
||||
train_data = train_data[:args.num_samples]
|
||||
dataset = WaveGradDataset(ap=ap,
|
||||
items=train_data,
|
||||
seq_len=ap.hop_length * 100,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=config.pad_short,
|
||||
conv_pad=config.conv_pad,
|
||||
is_training=True,
|
||||
return_segments=False,
|
||||
use_noise_augment=False,
|
||||
use_cache=False,
|
||||
verbose=True)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_full_clips,
|
||||
drop_last=False,
|
||||
num_workers=config.num_loader_workers,
|
||||
pin_memory=False)
|
||||
|
||||
# setup the model
|
||||
model = setup_generator(config)
|
||||
if args.use_cuda:
|
||||
model.cuda()
|
||||
|
||||
# setup optimization parameters
|
||||
base_values = sorted(np.random.uniform(high=10, size=args.search_depth))
|
||||
best_error = float('inf')
|
||||
best_schedule = None
|
||||
total_search_iter = len(base_values)**args.num_iter
|
||||
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
|
||||
model.compute_noise_level(num_steps=args.num_iter, min_val=1e-6, max_val=1e-1, base_vals=base)
|
||||
for data in loader:
|
||||
mel, audio = data
|
||||
y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
|
||||
|
||||
if args.use_cuda:
|
||||
y_hat = y_hat.cpu()
|
||||
y_hat = y_hat.numpy()
|
||||
|
||||
mel_hat = []
|
||||
for i in range(y_hat.shape[0]):
|
||||
m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
|
||||
mel_hat.append(torch.from_numpy(m))
|
||||
|
||||
mel_hat = torch.stack(mel_hat)
|
||||
mse = torch.sum((mel - mel_hat) ** 2)
|
||||
if mse.item() < best_error:
|
||||
best_error = mse.item()
|
||||
best_schedule = {'num_steps': args.num_iter, 'min_val':1e-6, 'max_val':1e-1, 'base_vals':base}
|
||||
print(" > Found a better schedule.")
|
||||
np.save(args.output_path, best_schedule)
|
||||
|
||||
|
|
@ -174,7 +174,7 @@ class AudioProcessor(object):
|
|||
for key in stats_config.keys():
|
||||
if key in skip_parameters:
|
||||
continue
|
||||
if key != 'sample_rate':
|
||||
if key not in ['sample_rate', 'trim_db']:
|
||||
assert stats_config[key] == self.__dict__[key],\
|
||||
f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
|
||||
return mel_mean, mel_std, linear_mean, linear_std, stats_config
|
||||
|
|
|
@ -81,11 +81,12 @@ class WaveGradDataset(Dataset):
|
|||
else:
|
||||
audio = self.ap.load_wav(wavpath)
|
||||
|
||||
# correct audio length wrt segment length
|
||||
if audio.shape[-1] < self.seq_len + self.pad_short:
|
||||
audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
|
||||
mode='constant', constant_values=0.0)
|
||||
assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
|
||||
if self.return_segments:
|
||||
# correct audio length wrt segment length
|
||||
if audio.shape[-1] < self.seq_len + self.pad_short:
|
||||
audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
|
||||
mode='constant', constant_values=0.0)
|
||||
assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
|
||||
|
||||
# correct the audio length wrt hop length
|
||||
p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
|
||||
|
@ -104,8 +105,26 @@ class WaveGradDataset(Dataset):
|
|||
audio = audio + (1 / 32768) * torch.randn_like(audio)
|
||||
|
||||
mel = self.ap.melspectrogram(audio)
|
||||
mel = mel[..., :-1]
|
||||
mel = mel[..., :-1] # ignore the padding
|
||||
|
||||
audio = torch.from_numpy(audio).float()
|
||||
mel = torch.from_numpy(mel).float().squeeze(0)
|
||||
return (mel, audio)
|
||||
|
||||
|
||||
def collate_full_clips(self, batch):
|
||||
"""This is used in tune_wavegrad.py.
|
||||
It pads sequences to the max length."""
|
||||
max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1]
|
||||
max_audio_length = max([b[1].shape[0] for b in batch]) if len(batch) > 1 else batch[0][1].shape[0]
|
||||
|
||||
mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length])
|
||||
audios = torch.zeros([len(batch), max_audio_length])
|
||||
|
||||
for idx, b in enumerate(batch):
|
||||
mel = b[0]
|
||||
audio = b[1]
|
||||
mels[idx, :, :mel.shape[1]] = mel
|
||||
audios[idx, :audio.shape[0]] = audio
|
||||
|
||||
return mels, audios
|
||||
|
|
|
@ -78,13 +78,21 @@ class Wavegrad(nn.Module):
|
|||
x = self.out_conv(x)
|
||||
return x
|
||||
|
||||
def load_noise_schedule(self, path):
|
||||
sched = np.load(path, allow_pickle=True).item()
|
||||
self.compute_noise_level(**sched)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x):
|
||||
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
|
||||
sqrt_alpha_hat = self.noise_level.unsqueeze(1).to(x)
|
||||
def inference(self, x, y_n=None):
|
||||
""" x: B x D X T """
|
||||
if y_n is None:
|
||||
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
|
||||
else:
|
||||
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
|
||||
sqrt_alpha_hat = self.noise_level.to(x)
|
||||
for n in range(len(self.alpha) - 1, -1, -1):
|
||||
y_n = self.c1[n] * (y_n -
|
||||
self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n]).squeeze(1))
|
||||
self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
|
||||
if n > 0:
|
||||
z = torch.randn_like(y_n)
|
||||
y_n += self.sigma[n - 1] * z
|
||||
|
@ -105,9 +113,11 @@ class Wavegrad(nn.Module):
|
|||
noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise
|
||||
return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
|
||||
|
||||
def compute_noise_level(self, num_steps, min_val, max_val):
|
||||
def compute_noise_level(self, num_steps, min_val, max_val, base_vals=None):
|
||||
"""Compute noise schedule parameters"""
|
||||
beta = np.linspace(min_val, max_val, num_steps)
|
||||
if base_vals is not None:
|
||||
beta *= base_vals
|
||||
alpha = 1 - beta
|
||||
alpha_hat = np.cumprod(alpha)
|
||||
noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
|
||||
|
|
Loading…
Reference in New Issue