update argument name external_speaker_embedding_dim -> speaker_embedding_dim

add inference_noise_scale argument to glow-tts
pull/441/head
Eren Gölge 2021-04-21 13:09:44 +02:00
parent 2da81f5bb6
commit d42748082a
1 changed files with 12 additions and 11 deletions

View File

@ -33,7 +33,7 @@ class GlowTTS(nn.Module):
mean_only (bool): if True, encoder only computes mean value and uses constant variance for each time step. mean_only (bool): if True, encoder only computes mean value and uses constant variance for each time step.
encoder_type (str): encoder module type. encoder_type (str): encoder module type.
encoder_params (dict): encoder module parameters. encoder_params (dict): encoder module parameters.
external_speaker_embedding_dim (int): channels of external speaker embedding vectors. speaker_embedding_dim (int): channels of external speaker embedding vectors.
""" """
def __init__( def __init__(
@ -45,6 +45,7 @@ class GlowTTS(nn.Module):
hidden_channels_dp, hidden_channels_dp,
out_channels, out_channels,
num_flow_blocks_dec=12, num_flow_blocks_dec=12,
inference_noise_scale=0.33,
kernel_size_dec=5, kernel_size_dec=5,
dilation_rate=5, dilation_rate=5,
num_block_layers=4, num_block_layers=4,
@ -58,10 +59,9 @@ class GlowTTS(nn.Module):
mean_only=False, mean_only=False,
encoder_type="transformer", encoder_type="transformer",
encoder_params=None, encoder_params=None,
external_speaker_embedding_dim=None, speaker_embedding_dim=None,
): ):
super().__init__() super().__init__()
self.num_chars = num_chars self.num_chars = num_chars
self.hidden_channels_dp = hidden_channels_dp self.hidden_channels_dp = hidden_channels_dp
@ -80,19 +80,20 @@ class GlowTTS(nn.Module):
self.sigmoid_scale = sigmoid_scale self.sigmoid_scale = sigmoid_scale
self.mean_only = mean_only self.mean_only = mean_only
self.use_encoder_prenet = use_encoder_prenet self.use_encoder_prenet = use_encoder_prenet
self.inference_noise_scale = inference_noise_scale
# model constants. # model constants.
self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference. self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference.
self.length_scale = 1.0 # scaler for the duration predictor. The larger it is, the slower the speech. self.length_scale = 1.0 # scaler for the duration predictor. The larger it is, the slower the speech.
self.external_speaker_embedding_dim = external_speaker_embedding_dim self.speaker_embedding_dim = speaker_embedding_dim
# if is a multispeaker and c_in_channels is 0, set to 256 # if is a multispeaker and c_in_channels is 0, set to 256
if num_speakers > 1: if num_speakers > 1:
if self.c_in_channels == 0 and not self.external_speaker_embedding_dim: if self.c_in_channels == 0 and not self.speaker_embedding_dim:
# TODO: make this adjustable # TODO: make this adjustable
self.c_in_channels = 256 self.c_in_channels = 256
elif self.external_speaker_embedding_dim: elif self.speaker_embedding_dim:
self.c_in_channels = self.external_speaker_embedding_dim self.c_in_channels = self.speaker_embedding_dim
self.encoder = Encoder( self.encoder = Encoder(
num_chars, num_chars,
@ -121,7 +122,7 @@ class GlowTTS(nn.Module):
c_in_channels=self.c_in_channels, c_in_channels=self.c_in_channels,
) )
if num_speakers > 1 and not external_speaker_embedding_dim: if num_speakers > 1 and not speaker_embedding_dim:
# speaker embedding layer # speaker embedding layer
self.emb_g = nn.Embedding(num_speakers, self.c_in_channels) self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
@ -151,7 +152,7 @@ class GlowTTS(nn.Module):
y_max_length = y.size(2) y_max_length = y.size(2)
# norm speaker embeddings # norm speaker embeddings
if g is not None: if g is not None:
if self.external_speaker_embedding_dim: if self.speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1) g = F.normalize(g).unsqueeze(-1)
else: else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
@ -181,7 +182,7 @@ class GlowTTS(nn.Module):
@torch.no_grad() @torch.no_grad()
def inference(self, x, x_lengths, g=None): def inference(self, x, x_lengths, g=None):
if g is not None: if g is not None:
if self.external_speaker_embedding_dim: if self.speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1) g = F.normalize(g).unsqueeze(-1)
else: else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
@ -200,7 +201,7 @@ class GlowTTS(nn.Module):
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.noise_scale) * y_mask z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask
# decoder pass # decoder pass
y, logdet = self.decoder(z, y_mask, g=g, reverse=True) y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
attn = attn.squeeze(1).permute(0, 2, 1) attn = attn.squeeze(1).permute(0, 2, 1)