small fixes

pull/10/head
erogol 2020-12-30 14:20:08 +01:00
parent 5901a00576
commit eb555855e4
2 changed files with 15 additions and 12 deletions

View File

@@ -15,7 +15,6 @@ class PositionalEncoding(nn.Module):
dropout (float): dropout parameter
dim (int): embedding size
"""
def __init__(self, dim, dropout=0.0, max_len=5000):
super().__init__()
if dim % 2 != 0:
@@ -117,33 +116,36 @@ class Encoder(nn.Module):
"""
super().__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.in_channels = in_hidden_channels
self.hidden_channels = in_hidden_channels
self.encoder_type = encoder_type
self.c_in_channels = c_in_channels
# init encoder
if encoder_type.lower() == "transformer":
# optional convolutional prenet
self.pre = ConvLayerNorm(hidden_channels,
hidden_channels,
hidden_channels,
self.pre = ConvLayerNorm(self.in_channels,
self.hidden_channels,
self.hidden_channels,
kernel_size=5,
num_layers=3,
dropout_p=0.5)
# text encoder
self.encoder = Transformer(hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg
self.encoder = Transformer(self.hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg
elif encoder_type.lower() == 'residual_conv_bn':
self.pre = nn.Sequential(
nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU())
self.encoder = ResidualConvBNBlock(hidden_channels,
nn.Conv1d(self.in_channels, self.hidden_channels, 1),
nn.ReLU())
self.encoder = ResidualConvBNBlock(self.hidden_channels,
**encoder_params)
else:
raise NotImplementedError(' [!] encoder type not implemented.')
# final projection layers
self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
self.post_bn = nn.BatchNorm1d(hidden_channels)
self.post_conv2 = nn.Conv1d(hidden_channels, out_channels, 1)
self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels,
1)
self.post_bn = nn.BatchNorm1d(self.hidden_channels)
self.post_conv2 = nn.Conv1d(self.hidden_channels, self.out_channels, 1)
def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument
# TODO: implement multi-speaker

View File

@@ -126,6 +126,7 @@ class GlowTts(nn.Module):
x_lenghts: B
y: B x C x T
y_lengths: B
g: B x C or B
"""
y_max_length = y.size(2)
# norm speaker embeddings
@@ -133,7 +134,7 @@ class GlowTts(nn.Module):
if self.external_speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h]
g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h, 1]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,