mirror of https://github.com/coqui-ai/TTS.git
Padding with functional interface to match TF "SAME"
parent
4a741c64b3
commit
00c0c9cde6
|
@ -56,18 +56,20 @@ class BatchNormConv1d(nn.Module):
|
|||
padding,
|
||||
activation=None):
|
||||
super(BatchNormConv1d, self).__init__()
|
||||
self.padding = padding
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
padding=0,
|
||||
bias=False)
|
||||
# Following tensorflow's default parameters
|
||||
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x):
|
||||
x = nn.functional.pad(x, self.padding)
|
||||
x = self.conv1d(x)
|
||||
if self.activation is not None:
|
||||
x = self.activation(x)
|
||||
|
@ -130,12 +132,12 @@ class CBHG(nn.Module):
|
|||
conv_bank_features,
|
||||
kernel_size=k,
|
||||
stride=1,
|
||||
padding=k // 2,
|
||||
padding=[(k - 1) // 2, k // 2],
|
||||
activation=self.relu) for k in range(1, K + 1)
|
||||
])
|
||||
# max pooling of conv bank
|
||||
# max pooling of conv bank, padding with nn.functional
|
||||
# TODO: try average pooling OR larger kernel size
|
||||
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
||||
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=0)
|
||||
out_features = [K * conv_bank_features] + conv_projections[:-1]
|
||||
activations = [self.relu] * (len(conv_projections) - 1)
|
||||
activations += [None]
|
||||
|
@ -148,7 +150,7 @@ class CBHG(nn.Module):
|
|||
out_size,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
padding=[1, 1],
|
||||
activation=ac)
|
||||
layer_set.append(layer)
|
||||
self.conv1d_projections = nn.ModuleList(layer_set)
|
||||
|
@ -181,11 +183,11 @@ class CBHG(nn.Module):
|
|||
outs = []
|
||||
for conv1d in self.conv1d_banks:
|
||||
out = conv1d(x)
|
||||
out = out[:, :, :T]
|
||||
outs.append(out)
|
||||
x = torch.cat(outs, dim=1)
|
||||
assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
|
||||
x = self.max_pool1d(x)[:, :, :T]
|
||||
x = nn.functional.pad(x, [0, 1])
|
||||
x = self.max_pool1d(x)
|
||||
for conv1d in self.conv1d_projections:
|
||||
x = conv1d(x)
|
||||
# (B, T_in, hid_feature)
|
||||
|
|
Loading…
Reference in New Issue