revert some of Hifigan generator updates

pull/422/head
Eren Gölge 2021-04-07 19:15:31 +02:00
parent 02bc776c35
commit 13dca6e6b6
1 changed files with 10 additions and 11 deletions

View File

@ -85,12 +85,12 @@ class ResBlock1(torch.nn.Module):
x: [B, C, T]
"""
for c1, c2 in zip(self.convs1, self.convs2):
o = F.leaky_relu(x, LRELU_SLOPE)
o = c1(o)
o = F.leaky_relu(o, LRELU_SLOPE)
o = c2(o)
o = o + x
return o
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
@ -134,10 +134,10 @@ class ResBlock2(torch.nn.Module):
def forward(self, x):
for c in self.convs:
o = F.leaky_relu(x, LRELU_SLOPE)
o = c(o)
o = o + x
return o
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
@ -223,7 +223,6 @@ class HifiganGenerator(torch.nn.Module):
o = F.leaky_relu(o)
o = self.conv_post(o)
o = torch.tanh(o)
breakpoint()
return o
@torch.no_grad()