mirror of https://github.com/coqui-ai/TTS.git

argument rename

parent 4ef083f0f1
commit 5901a00576
@@ -53,10 +53,9 @@ class PositionalEncoding(nn.Module):
         if self.pe.size(2) < x.size(2):
             raise RuntimeError(
                 f"Sequence is {x.size(2)} but PositionalEncoding is"
-                f" limited to {self.pe.size(2)}. See max_len argument."
-            )
+                f" limited to {self.pe.size(2)}. See max_len argument.")
         if mask is not None:
-            pos_enc = (self.pe[:, : ,:x.size(2)] * mask)
+            pos_enc = (self.pe[:, :, :x.size(2)] * mask)
         else:
             pos_enc = self.pe[:, :, :x.size(2)]
         x = x + pos_enc
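For context, a minimal self-contained sketch of the forward logic this hunk touches: the max_len guard and the optional mask. It assumes a channel-first `pe` buffer of shape [1, channels, max_len] and an even channel count; the real class's constructor and buffer construction are not shown in the diff and may differ.

```python
import math
import torch
from torch import nn


class PositionalEncodingSketch(nn.Module):
    """Sinusoidal positional encoding stored channel-first, as the hunk implies."""

    def __init__(self, channels, max_len=5000):
        super().__init__()
        # assumes even `channels`; build the usual sin/cos table [max_len, channels]
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, channels, 2, dtype=torch.float)
            * (-math.log(10000.0) / channels)
        )
        pe = torch.zeros(max_len, channels)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # store as [1, channels, max_len] so it broadcasts over x of shape [B, C, T]
        self.register_buffer("pe", pe.transpose(0, 1).unsqueeze(0))

    def forward(self, x, mask=None):
        # same guard and masking as the hunk above
        if self.pe.size(2) < x.size(2):
            raise RuntimeError(
                f"Sequence is {x.size(2)} but PositionalEncoding is"
                f" limited to {self.pe.size(2)}. See max_len argument.")
        if mask is not None:
            pos_enc = self.pe[:, :, :x.size(2)] * mask
        else:
            pos_enc = self.pe[:, :, :x.size(2)]
        return x + pos_enc


# purely illustrative shape check
x = torch.randn(2, 128, 50)                        # [batch, channels, time]
assert PositionalEncodingSketch(128)(x).shape == x.shape
```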
@@ -71,10 +70,10 @@ class Encoder(nn.Module):
     # pylint: disable=dangerous-default-value
     def __init__(
             self,
-            hidden_channels,
+            in_hidden_channels,
             out_channels,
             encoder_type='residual_conv_bn',
-            encoder_params = {
+            encoder_params={
                 "kernel_size": 4,
                 "dilations": 4 * [1, 2, 4] + [1],
                 "num_conv_blocks": 2,
@@ -86,7 +85,7 @@ class Encoder(nn.Module):
     Args:
         num_chars (int): number of characters.
         out_channels (int): number of output channels.
-        hidden_channels (int): encoder's embedding size.
+        in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers.
         encoder_type (str): encoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'.
         encoder_params (dict): model parameters for specified encoder type.
         c_in_channels (int): number of channels for conditional input.
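As a usage note for the rename, a hedged construction sketch built only from names visible in this diff; the import path, the `out_channels` value, and any constructor arguments beyond those shown (e.g. `c_in_channels`) are assumptions rather than facts from the hunks.

```python
# Import path is an assumption; the diff does not show which module the
# Encoder lives in at this commit.
from TTS.tts.layers.feed_forward.encoder import Encoder

encoder = Encoder(
    in_hidden_channels=128,       # renamed from `hidden_channels` in this commit
    out_channels=80,              # illustrative value, e.g. mel channels
    encoder_type='residual_conv_bn',
    encoder_params={              # defaults shown in the hunk above
        "kernel_size": 4,
        "dilations": 4 * [1, 2, 4] + [1],
        "num_conv_blocks": 2,
    },
)
```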
@@ -96,7 +95,7 @@ class Encoder(nn.Module):
 
     for 'transformer'
         encoder_params={
-            'hidden_channels_ffn': 768,
+            'hidden_channels_ffn': 128,
             'num_heads': 2,
             "kernel_size": 3,
             "dropout_p": 0.1,
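And the matching call for the 'transformer' branch, using the updated docstring example values. The docstring excerpt is truncated in the hunk, so the transformer layer may require additional `encoder_params` keys not shown here; this is a sketch, not the confirmed full parameter set.

```python
# Same assumed Encoder as in the sketch above.
encoder = Encoder(
    in_hidden_channels=128,
    out_channels=80,              # illustrative value
    encoder_type='transformer',
    encoder_params={
        'hidden_channels_ffn': 128,   # new default in this commit (was 768)
        'num_heads': 2,
        "kernel_size": 3,
        "dropout_p": 0.1,
        # remaining keys are truncated in the diff and intentionally omitted
    },
)
```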