mirror of https://github.com/coqui-ai/TTS.git
update argument name external_speaker_embedding_dim -> speaker_embedding_dim
add inference_noise_scale argument to glow-tts
pull/441/head
parent
2da81f5bb6
commit
d42748082a
|
@ -33,7 +33,7 @@ class GlowTTS(nn.Module):
|
||||||
mean_only (bool): if True, encoder only computes mean value and uses constant variance for each time step.
|
mean_only (bool): if True, encoder only computes mean value and uses constant variance for each time step.
|
||||||
encoder_type (str): encoder module type.
|
encoder_type (str): encoder module type.
|
||||||
encoder_params (dict): encoder module parameters.
|
encoder_params (dict): encoder module parameters.
|
||||||
external_speaker_embedding_dim (int): channels of external speaker embedding vectors.
|
speaker_embedding_dim (int): channels of external speaker embedding vectors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -45,6 +45,7 @@ class GlowTTS(nn.Module):
|
||||||
hidden_channels_dp,
|
hidden_channels_dp,
|
||||||
out_channels,
|
out_channels,
|
||||||
num_flow_blocks_dec=12,
|
num_flow_blocks_dec=12,
|
||||||
|
inference_noise_scale=0.33,
|
||||||
kernel_size_dec=5,
|
kernel_size_dec=5,
|
||||||
dilation_rate=5,
|
dilation_rate=5,
|
||||||
num_block_layers=4,
|
num_block_layers=4,
|
||||||
|
@ -58,10 +59,9 @@ class GlowTTS(nn.Module):
|
||||||
mean_only=False,
|
mean_only=False,
|
||||||
encoder_type="transformer",
|
encoder_type="transformer",
|
||||||
encoder_params=None,
|
encoder_params=None,
|
||||||
external_speaker_embedding_dim=None,
|
speaker_embedding_dim=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_chars = num_chars
|
self.num_chars = num_chars
|
||||||
self.hidden_channels_dp = hidden_channels_dp
|
self.hidden_channels_dp = hidden_channels_dp
|
||||||
|
@ -80,19 +80,20 @@ class GlowTTS(nn.Module):
|
||||||
self.sigmoid_scale = sigmoid_scale
|
self.sigmoid_scale = sigmoid_scale
|
||||||
self.mean_only = mean_only
|
self.mean_only = mean_only
|
||||||
self.use_encoder_prenet = use_encoder_prenet
|
self.use_encoder_prenet = use_encoder_prenet
|
||||||
|
self.inference_noise_scale = inference_noise_scale
|
||||||
|
|
||||||
# model constants.
|
# model constants.
|
||||||
self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference.
|
self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference.
|
||||||
self.length_scale = 1.0 # scaler for the duration predictor. The larger it is, the slower the speech.
|
self.length_scale = 1.0 # scaler for the duration predictor. The larger it is, the slower the speech.
|
||||||
self.external_speaker_embedding_dim = external_speaker_embedding_dim
|
self.speaker_embedding_dim = speaker_embedding_dim
|
||||||
|
|
||||||
# if is a multispeaker and c_in_channels is 0, set to 256
|
# if is a multispeaker and c_in_channels is 0, set to 256
|
||||||
if num_speakers > 1:
|
if num_speakers > 1:
|
||||||
if self.c_in_channels == 0 and not self.external_speaker_embedding_dim:
|
if self.c_in_channels == 0 and not self.speaker_embedding_dim:
|
||||||
# TODO: make this adjustable
|
# TODO: make this adjustable
|
||||||
self.c_in_channels = 256
|
self.c_in_channels = 256
|
||||||
elif self.external_speaker_embedding_dim:
|
elif self.speaker_embedding_dim:
|
||||||
self.c_in_channels = self.external_speaker_embedding_dim
|
self.c_in_channels = self.speaker_embedding_dim
|
||||||
|
|
||||||
self.encoder = Encoder(
|
self.encoder = Encoder(
|
||||||
num_chars,
|
num_chars,
|
||||||
|
@ -121,7 +122,7 @@ class GlowTTS(nn.Module):
|
||||||
c_in_channels=self.c_in_channels,
|
c_in_channels=self.c_in_channels,
|
||||||
)
|
)
|
||||||
|
|
||||||
if num_speakers > 1 and not external_speaker_embedding_dim:
|
if num_speakers > 1 and not speaker_embedding_dim:
|
||||||
# speaker embedding layer
|
# speaker embedding layer
|
||||||
self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
|
self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
|
||||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||||
|
@ -151,7 +152,7 @@ class GlowTTS(nn.Module):
|
||||||
y_max_length = y.size(2)
|
y_max_length = y.size(2)
|
||||||
# norm speaker embeddings
|
# norm speaker embeddings
|
||||||
if g is not None:
|
if g is not None:
|
||||||
if self.external_speaker_embedding_dim:
|
if self.speaker_embedding_dim:
|
||||||
g = F.normalize(g).unsqueeze(-1)
|
g = F.normalize(g).unsqueeze(-1)
|
||||||
else:
|
else:
|
||||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||||
|
@ -181,7 +182,7 @@ class GlowTTS(nn.Module):
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, x, x_lengths, g=None):
|
def inference(self, x, x_lengths, g=None):
|
||||||
if g is not None:
|
if g is not None:
|
||||||
if self.external_speaker_embedding_dim:
|
if self.speaker_embedding_dim:
|
||||||
g = F.normalize(g).unsqueeze(-1)
|
g = F.normalize(g).unsqueeze(-1)
|
||||||
else:
|
else:
|
||||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
|
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
|
||||||
|
@ -200,7 +201,7 @@ class GlowTTS(nn.Module):
|
||||||
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
|
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
|
||||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
||||||
|
|
||||||
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.noise_scale) * y_mask
|
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask
|
||||||
# decoder pass
|
# decoder pass
|
||||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||||
|
|
Loading…
Reference in New Issue