argument rename

pull/10/head
erogol 2020-12-30 14:19:23 +01:00
parent 4ef083f0f1
commit 5901a00576
1 changed files with 6 additions and 7 deletions

View File

@ -53,10 +53,9 @@ class PositionalEncoding(nn.Module):
if self.pe.size(2) < x.size(2):
raise RuntimeError(
f"Sequence is {x.size(2)} but PositionalEncoding is"
f" limited to {self.pe.size(2)}. See max_len argument."
)
f" limited to {self.pe.size(2)}. See max_len argument.")
if mask is not None:
pos_enc = (self.pe[:, : ,:x.size(2)] * mask)
pos_enc = (self.pe[:, :, :x.size(2)] * mask)
else:
pos_enc = self.pe[:, :, :x.size(2)]
x = x + pos_enc
@ -71,10 +70,10 @@ class Encoder(nn.Module):
# pylint: disable=dangerous-default-value
def __init__(
self,
hidden_channels,
in_hidden_channels,
out_channels,
encoder_type='residual_conv_bn',
encoder_params = {
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
@ -86,7 +85,7 @@ class Encoder(nn.Module):
Args:
num_chars (int): number of characters.
out_channels (int): number of output channels.
hidden_channels (int): encoder's embedding size.
in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers.
encoder_type (str): encoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'.
encoder_params (dict): model parameters for specified encoder type.
c_in_channels (int): number of channels for conditional input.
@ -96,7 +95,7 @@ class Encoder(nn.Module):
for 'transformer'
encoder_params={
'hidden_channels_ffn': 768,
'hidden_channels_ffn': 128,
'num_heads': 2,
"kernel_size": 3,
"dropout_p": 0.1,