From 5901a005767c6bfa800ae742ae68e9de030aaa67 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 30 Dec 2020 14:19:23 +0100 Subject: [PATCH] argument rename --- TTS/tts/layers/speedy_speech/encoder.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/TTS/tts/layers/speedy_speech/encoder.py b/TTS/tts/layers/speedy_speech/encoder.py index 755c8521..121c6d9e 100644 --- a/TTS/tts/layers/speedy_speech/encoder.py +++ b/TTS/tts/layers/speedy_speech/encoder.py @@ -53,10 +53,9 @@ class PositionalEncoding(nn.Module): if self.pe.size(2) < x.size(2): raise RuntimeError( f"Sequence is {x.size(2)} but PositionalEncoding is" - f" limited to {self.pe.size(2)}. See max_len argument." - ) + f" limited to {self.pe.size(2)}. See max_len argument.") if mask is not None: - pos_enc = (self.pe[:, : ,:x.size(2)] * mask) + pos_enc = (self.pe[:, :, :x.size(2)] * mask) else: pos_enc = self.pe[:, :, :x.size(2)] x = x + pos_enc @@ -71,10 +70,10 @@ class Encoder(nn.Module): # pylint: disable=dangerous-default-value def __init__( self, - hidden_channels, + in_hidden_channels, out_channels, encoder_type='residual_conv_bn', - encoder_params = { + encoder_params={ "kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, @@ -86,7 +85,7 @@ class Encoder(nn.Module): Args: num_chars (int): number of characters. out_channels (int): number of output channels. - hidden_channels (int): encoder's embedding size. + in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers. encoder_type (str): encoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'. encoder_params (dict): model parameters for specified encoder type. c_in_channels (int): number of channels for conditional input. @@ -96,7 +95,7 @@ class Encoder(nn.Module): for 'transformer' encoder_params={ - 'hidden_channels_ffn': 768, + 'hidden_channels_ffn': 128, 'num_heads': 2, "kernel_size": 3, "dropout_p": 0.1,