align tts MDN layer

pull/373/head
Eren Gölge 2021-03-03 15:41:21 +01:00 committed by Eren Gölge
parent 4396f8e2da
commit a831468cab
1 changed files with 25 additions and 0 deletions

View File

@ -0,0 +1,25 @@
import torch
from torch import nn
from ..generic.normalization import LayerNorm
class MDNBlock(nn.Module):
    """Mixture of Density Network (MDN) block.

    Reference: https://arxiv.org/pdf/2003.01950.pdf

    The block applies a kernel-size-1 ``Conv1d``, the project ``LayerNorm``,
    ``ReLU`` and ``Dropout``, followed by a second kernel-size-1 ``Conv1d``
    that projects to ``out_channels``. The projection is then split along
    the channel axis (dim 1) into ``mu`` and ``log_sigma``.

    Args:
        in_channels: number of input channels; also used as the hidden width
            of the first convolution.
        out_channels: number of channels produced before the split. The first
            ``out_channels // 2`` channels become ``mu`` and the rest become
            ``log_sigma``, so an odd value gives ``log_sigma`` one channel
            more than ``mu``.
        dropout_p: dropout probability applied after the ReLU. Defaults to
            ``0.1``, the value that used to be hard-coded.
    """

    def __init__(self, in_channels, out_channels, dropout_p=0.1):
        super().__init__()
        self.out_channels = out_channels
        self.mdn = nn.Sequential(
            nn.Conv1d(in_channels, in_channels, 1),
            LayerNorm(in_channels),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Conv1d(in_channels, out_channels, 1),
        )

    def forward(self, x):
        """Project ``x`` and split the result into ``(mu, log_sigma)``.

        Args:
            x: 3-D tensor with ``in_channels`` on dim 1
                (presumably ``(batch, in_channels, time)`` -- confirm with
                the caller).

        Returns:
            A tuple ``(mu, log_sigma)`` holding the first
            ``out_channels // 2`` channels and the remaining channels of the
            projected tensor, respectively.
        """
        mu_sigma = self.mdn(x)
        # Split point shared by both slices so they always partition dim 1.
        half = self.out_channels // 2
        # TODO: check whether mu should be squashed with a sigmoid, i.e.
        # torch.sigmoid(mu_sigma[:, :half, :]).
        mu = mu_sigma[:, :half, :]
        log_sigma = mu_sigma[:, half:, :]
        return mu, log_sigma