mirror of https://github.com/coqui-ai/TTS.git
argument rename
parent
4ef083f0f1
commit
5901a00576
|
@ -53,10 +53,9 @@ class PositionalEncoding(nn.Module):
|
||||||
if self.pe.size(2) < x.size(2):
|
if self.pe.size(2) < x.size(2):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Sequence is {x.size(2)} but PositionalEncoding is"
|
f"Sequence is {x.size(2)} but PositionalEncoding is"
|
||||||
f" limited to {self.pe.size(2)}. See max_len argument."
|
f" limited to {self.pe.size(2)}. See max_len argument.")
|
||||||
)
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
pos_enc = (self.pe[:, : ,:x.size(2)] * mask)
|
pos_enc = (self.pe[:, :, :x.size(2)] * mask)
|
||||||
else:
|
else:
|
||||||
pos_enc = self.pe[:, :, :x.size(2)]
|
pos_enc = self.pe[:, :, :x.size(2)]
|
||||||
x = x + pos_enc
|
x = x + pos_enc
|
||||||
|
@ -71,10 +70,10 @@ class Encoder(nn.Module):
|
||||||
# pylint: disable=dangerous-default-value
|
# pylint: disable=dangerous-default-value
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hidden_channels,
|
in_hidden_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
encoder_type='residual_conv_bn',
|
encoder_type='residual_conv_bn',
|
||||||
encoder_params = {
|
encoder_params={
|
||||||
"kernel_size": 4,
|
"kernel_size": 4,
|
||||||
"dilations": 4 * [1, 2, 4] + [1],
|
"dilations": 4 * [1, 2, 4] + [1],
|
||||||
"num_conv_blocks": 2,
|
"num_conv_blocks": 2,
|
||||||
|
@ -86,7 +85,7 @@ class Encoder(nn.Module):
|
||||||
Args:
|
Args:
|
||||||
num_chars (int): number of characters.
|
num_chars (int): number of characters.
|
||||||
out_channels (int): number of output channels.
|
out_channels (int): number of output channels.
|
||||||
hidden_channels (int): encoder's embedding size.
|
in_hidden_channels (int): number of input and hidden channels. The model keeps the input channel count for its intermediate layers.
|
||||||
encoder_type (str): encoder layer types. 'transformer' or 'residual_conv_bn'. Default 'residual_conv_bn'.
|
encoder_type (str): encoder layer types. 'transformer' or 'residual_conv_bn'. Default 'residual_conv_bn'.
|
||||||
encoder_params (dict): model parameters for specified encoder type.
|
encoder_params (dict): model parameters for specified encoder type.
|
||||||
c_in_channels (int): number of channels for conditional input.
|
c_in_channels (int): number of channels for conditional input.
|
||||||
|
@ -96,7 +95,7 @@ class Encoder(nn.Module):
|
||||||
|
|
||||||
for 'transformer'
|
for 'transformer'
|
||||||
encoder_params={
|
encoder_params={
|
||||||
'hidden_channels_ffn': 768,
|
'hidden_channels_ffn': 128,
|
||||||
'num_heads': 2,
|
'num_heads': 2,
|
||||||
"kernel_size": 3,
|
"kernel_size": 3,
|
||||||
"dropout_p": 0.1,
|
"dropout_p": 0.1,
|
||||||
|
|
Loading…
Reference in New Issue