mirror of https://github.com/coqui-ai/TTS.git
small fixes
parent 5901a00576
commit eb555855e4
@@ -15,7 +15,6 @@ class PositionalEncoding(nn.Module):
         dropout (float): dropout parameter
         dim (int): embedding size
     """
-
     def __init__(self, dim, dropout=0.0, max_len=5000):
         super().__init__()
         if dim % 2 != 0:
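For context, the constructor in this hunk follows the standard sinusoidal positional-encoding recipe, which is why dim must be even (sine and cosine channels are interleaved). A minimal sketch of such a module, matching the signature above but not taken from the repository, could look like:

# Minimal sketch of a sinusoidal positional-encoding module with the same
# constructor signature as above; an illustration, not the repository's code.
import math
import torch
from torch import nn

class PositionalEncodingSketch(nn.Module):
    def __init__(self, dim, dropout=0.0, max_len=5000):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError("dim must be even for sin/cos interleaving")
        self.dropout = nn.Dropout(dropout)
        position = torch.arange(0, max_len).float().unsqueeze(1)   # [max_len, 1]
        div_term = torch.exp(torch.arange(0, dim, 2).float()
                             * -(math.log(10000.0) / dim))          # [dim / 2]
        pe = torch.zeros(max_len, dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # stored as [1, dim, max_len] so it broadcasts over a [B, C, T] input
        self.register_buffer("pe", pe.transpose(0, 1).unsqueeze(0))

    def forward(self, x):
        # x: [B, C, T] with C == dim
        return self.dropout(x + self.pe[:, :, :x.size(2)])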
@@ -117,33 +116,36 @@ class Encoder(nn.Module):
         """
         super().__init__()
         self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
+        self.in_channels = in_hidden_channels
+        self.hidden_channels = in_hidden_channels
         self.encoder_type = encoder_type
         self.c_in_channels = c_in_channels

         # init encoder
         if encoder_type.lower() == "transformer":
             # optional convolutional prenet
-            self.pre = ConvLayerNorm(hidden_channels,
-                                     hidden_channels,
-                                     hidden_channels,
+            self.pre = ConvLayerNorm(self.in_channels,
+                                     self.hidden_channels,
+                                     self.hidden_channels,
                                      kernel_size=5,
                                      num_layers=3,
                                      dropout_p=0.5)
             # text encoder
-            self.encoder = Transformer(hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg
+            self.encoder = Transformer(self.hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg
         elif encoder_type.lower() == 'residual_conv_bn':
             self.pre = nn.Sequential(
-                nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU())
-            self.encoder = ResidualConvBNBlock(hidden_channels,
+                nn.Conv1d(self.in_channels, self.hidden_channels, 1),
+                nn.ReLU())
+            self.encoder = ResidualConvBNBlock(self.hidden_channels,
                                                **encoder_params)
         else:
             raise NotImplementedError(' [!] encoder type not implemented.')

         # final projection layers
-        self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
-        self.post_bn = nn.BatchNorm1d(hidden_channels)
-        self.post_conv2 = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels,
+                                   1)
+        self.post_bn = nn.BatchNorm1d(self.hidden_channels)
+        self.post_conv2 = nn.Conv1d(self.hidden_channels, self.out_channels, 1)

     def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument
         # TODO: implement multi-speaker
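The point of the Encoder change above is that in_channels and hidden_channels are now tracked separately, so the prenet projects the input width to the hidden width instead of requiring the two to be equal. A small, self-contained shape sketch of that channel flow (layer names and sizes here are illustrative, not the module's API):

# Illustrative channel flow only; sizes and variable names are made up.
import torch
from torch import nn

in_channels, hidden_channels, out_channels = 128, 192, 80

pre = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU())
post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
post_bn = nn.BatchNorm1d(hidden_channels)
post_conv2 = nn.Conv1d(hidden_channels, out_channels, 1)

x = torch.randn(2, in_channels, 50)      # [B, in_channels, T]
h = pre(x)                               # [B, hidden_channels, T]
o = post_conv2(torch.relu(post_bn(post_conv(h))))
print(o.shape)                           # torch.Size([2, 80, 50])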
@@ -126,6 +126,7 @@ class GlowTts(nn.Module):
             x_lenghts: B
             y: B x C x T
             y_lengths: B
+            g: B x C or B
         """
         y_max_length = y.size(2)
         # norm speaker embeddings
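The added docstring line documents the expected shape of g alongside the other inputs. A hedged example of tensors laid out with the documented shapes (dimension sizes are arbitrary, and the model call itself is omitted):

# Dummy tensors matching the documented shapes; sizes are arbitrary.
import torch

B, T_in, C_mel, T_mel, n_speakers = 2, 20, 80, 120, 10

x = torch.randint(0, 100, (B, T_in))        # input token ids
x_lengths = torch.tensor([20, 15])          # B
y = torch.randn(B, C_mel, T_mel)            # B x C x T
y_lengths = torch.tensor([120, 100])        # B
g_ids = torch.randint(0, n_speakers, (B,))  # B (speaker ids)
g_vecs = torch.randn(B, 256)                # B x C (external speaker embeddings)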
@@ -133,7 +134,7 @@ class GlowTts(nn.Module):
            if self.external_speaker_embedding_dim:
                g = F.normalize(g).unsqueeze(-1)
            else:
-               g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h]
+               g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h, 1]

        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
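The corrected comment reflects the actual shape: after the embedding lookup, L2 normalization, and unsqueeze(-1), g is [b, h, 1] rather than [b, h]. A quick standalone shape check (emb_g here is a stand-in embedding table, not the model's attribute):

# Stand-in check of the shape noted in the fixed comment.
import torch
import torch.nn.functional as F
from torch import nn

b, h, n_speakers = 2, 64, 10
emb_g = nn.Embedding(n_speakers, h)     # stand-in for the model's speaker embedding

speaker_ids = torch.tensor([3, 7])                  # [b]
g = F.normalize(emb_g(speaker_ids)).unsqueeze(-1)   # [b, h, 1]
print(g.shape)                                      # torch.Size([2, 64, 1])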