mirror of https://github.com/coqui-ai/TTS.git
wavegrad model and layers refactoring
parent dc2825dfb2
commit b76a0be97a
@@ -2,7 +2,7 @@ import numpy as np
 import torch
 from torch import nn
 
-from ..layers.wavegrad import DBlock, FiLM, UBlock
+from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d
 
 
 class Wavegrad(nn.Module):
@@ -10,8 +10,8 @@ class Wavegrad(nn.Module):
     def __init__(self,
                  in_channels=80,
                  out_channels=1,
-                 x_conv_channels=32,
-                 c_conv_channels=768,
+                 y_conv_channels=32,
+                 x_conv_channels=768,
                  dblock_out_channels=[128, 128, 256, 512],
                  ublock_out_channels=[512, 512, 256, 128, 128],
                  upsample_factors=[5, 5, 3, 2, 2],
@@ -19,106 +19,87 @@ class Wavegrad(nn.Module):
                                      [1, 2, 4, 8], [1, 2, 4, 8]]):
         super().__init__()
 
-        assert len(upsample_factors) == len(upsample_dilations)
-        assert len(upsample_factors) == len(ublock_out_channels)
+        self.hop_len = np.prod(upsample_factors)
 
-        # setup up-down sampling parameters
-        self.hop_length = np.prod(upsample_factors)
-        self.upsample_factors = upsample_factors
-        self.downsample_factors = upsample_factors[::-1][:-1]
-
-        ### define DBlocks, FiLM layers ###
+        # dblocks
         self.dblocks = nn.ModuleList([
-            nn.Conv1d(out_channels, x_conv_channels, 5, padding=2),
+            Conv1d(1, y_conv_channels, 5, padding=2),
         ])
-        ic = x_conv_channels
-        self.films = nn.ModuleList([])
-        for oc, df in zip(dblock_out_channels, self.downsample_factors):
-            # print('dblock(', ic, ', ', oc, ', ', df, ")")
-            layer = DBlock(ic, oc, df)
-            self.dblocks.append(layer)
-
-            # print('film(', ic, ', ', oc,")")
-            layer = FiLM(ic, oc)
-            self.films.append(layer)
+        ic = y_conv_channels
+        for oc, df in zip(dblock_out_channels, reversed(upsample_factors)):
+            self.dblocks.append(DBlock(ic, oc, df))
             ic = oc
-        # last FiLM block
-        # print('film(', ic, ', ', dblock_out_channels[-1],")")
-        self.films.append(FiLM(ic, dblock_out_channels[-1]))
 
-        ### define UBlocks ###
-        self.c_conv = nn.Conv1d(in_channels, c_conv_channels, 3, padding=1)
+        # film
+        self.film = nn.ModuleList([])
+        ic = y_conv_channels
+        for oc in reversed(ublock_out_channels):
+            self.film.append(FiLM(ic, oc))
+            ic = oc
 
+        # ublocks
         self.ublocks = nn.ModuleList([])
-        ic = c_conv_channels
-        for idx, (oc, uf) in enumerate(zip(ublock_out_channels, self.upsample_factors)):
-            # print('ublock(', ic, ', ', oc, ', ', uf, ")")
-            layer = UBlock(ic, oc, uf, upsample_dilations[idx])
-            self.ublocks.append(layer)
+        ic = x_conv_channels
+        for oc, uf, ud in zip(ublock_out_channels, upsample_factors, upsample_dilations):
+            self.ublocks.append(UBlock(ic, oc, uf, ud))
             ic = oc
 
-        # define last layer
-        # print(ic, 'last_conv--', out_channels)
-        self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1)
+        self.x_conv = Conv1d(in_channels, x_conv_channels, 3, padding=1)
+        self.out_conv = Conv1d(oc, out_channels, 3, padding=1)
 
-        # inference time noise schedule params
-        self.S = 1000
-        self.init_noise_schedule(self.S)
+    def forward(self, x, spectrogram, noise_scale):
+        downsampled = []
+        for film, layer in zip(self.film, self.dblocks):
+            x = layer(x)
+            downsampled.append(film(x, noise_scale))
 
-    def init_noise_schedule(self, num_iter, min_val=1e-6, max_val=0.01):
-        """compute noise schedule parameters"""
-        device = self.last_conv.weight.device
-        beta = torch.linspace(min_val, max_val, num_iter).to(device)
-        alpha = 1 - beta
-        alpha_cum = alpha.cumprod(dim=0)
-        noise_level = torch.cat([torch.FloatTensor([1]).to(device), alpha_cum ** 0.5])
-
-        self.register_buffer('beta', beta)
-        self.register_buffer('alpha', alpha)
-        self.register_buffer('alpha_cum', alpha_cum)
-        self.register_buffer('noise_level', noise_level)
-
-    def compute_noisy_x(self, x):
-        B = x.shape[0]
-        if len(x.shape) == 3:
-            x = x.squeeze(1)
-        s = torch.randint(1, self.S + 1, [B]).to(x).long()
-        l_a, l_b = self.noise_level[s-1], self.noise_level[s]
-        noise_scale = l_a + torch.rand(B).to(x) * (l_b - l_a)
-        noise_scale = noise_scale.unsqueeze(1)
-        noise = torch.randn_like(x)
-        noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise
-        return noise.unsqueeze(1), noisy_x.unsqueeze(1), noise_scale[:, 0]
-
-    def forward(self, x, c, noise_scale):
-        assert len(c.shape) == 3  # B, C, T
-        assert len(x.shape) == 3  # B, 1, T
-        o = x
-        shift_and_scales = []
-        for film, dblock in zip(self.films, self.dblocks):
-            o = dblock(o)
-            shift_and_scales.append(film(o, noise_scale))
-
-        o = self.c_conv(c)
-        for ublock, (film_shift, film_scale) in zip(self.ublocks,
-                                                    reversed(shift_and_scales)):
-            o = ublock(o, film_shift, film_scale)
-        o = self.last_conv(o)
-        return o
-
-    def inference(self, c):
-        with torch.no_grad():
-            x = torch.randn(c.shape[0], 1, self.hop_length * c.shape[-1]).to(c)
-            noise_scale = (self.alpha_cum**0.5).unsqueeze(1).to(c)
-            for n in range(len(self.alpha) - 1, -1, -1):
-                c1 = 1 / self.alpha[n]**0.5
-                c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5
-                x = c1 * (x -
-                          c2 * self.forward(x, c, noise_scale[n]).squeeze(1))
-                if n > 0:
-                    noise = torch.randn_like(x)
-                    sigma = ((1.0 - self.alpha_cum[n - 1]) /
-                             (1.0 - self.alpha_cum[n]) * self.beta[n])**0.5
-                    x += sigma * noise
-                x = torch.clamp(x, -1.0, 1.0)
+        x = self.x_conv(spectrogram)
+        for layer, (film_shift, film_scale) in zip(self.ublocks,
+                                                   reversed(downsampled)):
+            x = layer(x, film_shift, film_scale)
+        x = self.out_conv(x)
         return x
+
+    @torch.no_grad()
+    def inference(self, x):
+        y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
+        sqrt_alpha_hat = self.noise_level.unsqueeze(1).to(x)
+        for n in range(len(self.alpha) - 1, -1, -1):
+            y_n = self.c1[n] * (y_n -
+                                self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n]).squeeze(1))
+            if n > 0:
+                z = torch.randn_like(y_n)
+                y_n += self.sigma[n - 1] * z
+            y_n.clamp_(-1.0, 1.0)
+        return y_n
+
+    def compute_y_n(self, y_0):
+        self.noise_level = self.noise_level.to(y_0)
+        if len(y_0.shape) == 3:
+            y_0 = y_0.squeeze(1)
+        s = torch.randint(1, self.num_steps + 1, [y_0.shape[0]])
+        l_a, l_b = self.noise_level[s-1], self.noise_level[s]
+        noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
+        noise_scale = noise_scale.unsqueeze(1)
+        noise = torch.randn_like(y_0)
+        noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise
+        return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
+
+    def compute_noise_level(self, num_steps, min_val, max_val):
+        beta = np.linspace(min_val, max_val, num_steps)
+        alpha = 1 - beta
+        alpha_hat = np.cumprod(alpha)
+        noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)

+        self.num_steps = num_steps
+        self.beta = torch.tensor(beta.astype(np.float32))
+        self.alpha = torch.tensor(alpha.astype(np.float32))
+        self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
+        self.noise_level = torch.tensor(noise_level.astype(np.float32))
+
+        self.c1 = 1 / self.alpha**0.5
+        self.c2 = (1 - self.alpha) / (1 - self.alpha_hat)**0.5
+        self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:])**0.5
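As a quick orientation for reviewers, here is a minimal usage sketch of the refactored interface shown in the hunk above; it is not part of the commit. The import path, batch and frame sizes, and the L1 criterion are assumptions for illustration; the schedule values mirror the old init_noise_schedule() defaults, and only the method names and signatures come from the diff.

import torch
import torch.nn.functional as F
from TTS.vocoder.models.wavegrad import Wavegrad  # assumed module path

model = Wavegrad(in_channels=80, out_channels=1)

# training side: build the noise schedule once, then sample a noisy target per batch
model.compute_noise_level(num_steps=1000, min_val=1e-6, max_val=0.01)
mel = torch.randn(4, 80, 25)                      # [B, C, T] conditioning frames (illustrative)
y_0 = torch.randn(4, 1, 25 * int(model.hop_len))  # [B, 1, T * hop_len] target audio (illustrative)
noise, y_noisy, noise_scale = model.compute_y_n(y_0)
noise_hat = model(y_noisy, mel, noise_scale)      # forward(x, spectrogram, noise_scale)
loss = F.l1_loss(noise_hat, noise)                # assumed training criterion

# inference side: iterative denoising from Gaussian noise, conditioned on the mel
audio = model.inference(mel)                      # [B, 1, 25 * hop_len]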
@@ -119,7 +119,7 @@ def setup_generator(c):
            in_channels=c['audio']['num_mels'],
            out_channels=1,
            x_conv_channels=c['model_params']['x_conv_channels'],
-           c_conv_channels=c['model_params']['c_conv_channels'],
+           y_conv_channels=c['model_params']['y_conv_channels'],
            dblock_out_channels=c['model_params']['dblock_out_channels'],
            ublock_out_channels=c['model_params']['ublock_out_channels'],
            upsample_factors=c['model_params']['upsample_factors'],
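For reference, a 'model_params' fragment that matches the updated setup_generator() call above; it is illustrative and not part of the commit. Only the keys visible in the hunk are confirmed; the nesting, the values (copied from the new constructor defaults), and the upsample_dilations key are assumptions.

c = {
    'audio': {'num_mels': 80},
    'model_params': {
        'y_conv_channels': 32,
        'x_conv_channels': 768,
        'dblock_out_channels': [128, 128, 256, 512],
        'ublock_out_channels': [512, 512, 256, 128, 128],
        'upsample_factors': [5, 5, 3, 2, 2],
        # key assumed; the hunk is truncated before the dilation argument
        'upsample_dilations': [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8],
                               [1, 2, 4, 8], [1, 2, 4, 8]],
    },
}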