argument rename

pull/10/head
erogol 2020-12-30 14:19:23 +01:00
parent 4ef083f0f1
commit 5901a00576
1 changed files with 6 additions and 7 deletions

View File

@ -53,10 +53,9 @@ class PositionalEncoding(nn.Module):
if self.pe.size(2) < x.size(2): if self.pe.size(2) < x.size(2):
raise RuntimeError( raise RuntimeError(
f"Sequence is {x.size(2)} but PositionalEncoding is" f"Sequence is {x.size(2)} but PositionalEncoding is"
f" limited to {self.pe.size(2)}. See max_len argument." f" limited to {self.pe.size(2)}. See max_len argument.")
)
if mask is not None: if mask is not None:
pos_enc = (self.pe[:, : ,:x.size(2)] * mask) pos_enc = (self.pe[:, :, :x.size(2)] * mask)
else: else:
pos_enc = self.pe[:, :, :x.size(2)] pos_enc = self.pe[:, :, :x.size(2)]
x = x + pos_enc x = x + pos_enc
@ -71,10 +70,10 @@ class Encoder(nn.Module):
# pylint: disable=dangerous-default-value # pylint: disable=dangerous-default-value
def __init__( def __init__(
self, self,
hidden_channels, in_hidden_channels,
out_channels, out_channels,
encoder_type='residual_conv_bn', encoder_type='residual_conv_bn',
encoder_params = { encoder_params={
"kernel_size": 4, "kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1], "dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2, "num_conv_blocks": 2,
@ -86,7 +85,7 @@ class Encoder(nn.Module):
Args: Args:
num_chars (int): number of characters. num_chars (int): number of characters.
out_channels (int): number of output channels. out_channels (int): number of output channels.
hidden_channels (int): encoder's embedding size. in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers.
encoder_type (str): encoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'. encoder_type (str): encoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'.
encoder_params (dict): model parameters for specified encoder type. encoder_params (dict): model parameters for specified encoder type.
c_in_channels (int): number of channels for conditional input. c_in_channels (int): number of channels for conditional input.
@ -96,7 +95,7 @@ class Encoder(nn.Module):
for 'transformer' for 'transformer'
encoder_params={ encoder_params={
'hidden_channels_ffn': 768, 'hidden_channels_ffn': 128,
'num_heads': 2, 'num_heads': 2,
"kernel_size": 3, "kernel_size": 3,
"dropout_p": 0.1, "dropout_p": 0.1,