Reformat multi-speaker handling in GlowTTS

pull/792/head
Eren Gölge 2021-09-06 14:27:13 +00:00
parent 8d41060d36
commit d847a68e42
1 changed file with 14 additions and 10 deletions

@@ -109,6 +109,10 @@ class GlowTTS(BaseTTS):
         # init speaker manager
         self.speaker_manager = get_speaker_manager(config, data=data)
         self.num_speakers = self.speaker_manager.num_speakers
+        if config.use_d_vector_file:
+            self.external_d_vector_dim = config.d_vector_dim
+        else:
+            self.external_d_vector_dim = 0
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
             self.embedded_speaker_dim = self.c_in_channels
@@ -129,7 +133,7 @@ class GlowTTS(BaseTTS):
         return y_mean, y_log_scale, o_attn_dur

     def forward(
-        self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None}
+        self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
     ): # pylint: disable=dangerous-default-value
         """
         Shapes:
@@ -143,8 +147,8 @@ class GlowTTS(BaseTTS):
         y_max_length = y.size(2)
         # norm speaker embeddings
         g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
-        if g is not None:
-            if self.d_vector_dim:
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if not self.use_d_vector_file:
                 g = F.normalize(g).unsqueeze(-1)
             else:
                 g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
@@ -181,7 +185,7 @@ class GlowTTS(BaseTTS):

     @torch.no_grad()
     def inference_with_MAS(
-        self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None}
+        self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
     ): # pylint: disable=dangerous-default-value
         """
         It's similar to the teacher forcing in Tacotron.
@@ -198,12 +202,11 @@ class GlowTTS(BaseTTS):
         y_max_length = y.size(2)
         # norm speaker embeddings
         g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
-        if g is not None:
-            if self.external_d_vector_dim:
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if not self.use_d_vector_file:
                 g = F.normalize(g).unsqueeze(-1)
             else:
                 g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
-
         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
         # drop redisual frames wrt num_squeeze and set y_lengths.
@@ -243,7 +246,7 @@ class GlowTTS(BaseTTS):

     @torch.no_grad()
     def decoder_inference(
-        self, y, y_lengths=None, aux_input={"d_vectors": None}
+        self, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
     ): # pylint: disable=dangerous-default-value
         """
         Shapes:
@@ -275,7 +278,7 @@ class GlowTTS(BaseTTS):
         return outputs

     @torch.no_grad()
-    def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}): # pylint: disable=dangerous-default-value
+    def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids":None}): # pylint: disable=dangerous-default-value
         x_lengths = aux_input["x_lengths"]
         g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
@@ -326,8 +329,9 @@ class GlowTTS(BaseTTS):
         mel_input = batch["mel_input"]
         mel_lengths = batch["mel_lengths"]
         d_vectors = batch["d_vectors"]
+        speaker_ids = batch["speaker_ids"]

-        outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors})
+        outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids":speaker_ids})

         loss_dict = criterion(
             outputs["model_outputs"],
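For orientation, a minimal usage sketch (not part of the commit) of the widened aux_input dict that this change threads through forward/inference and train_step. The tensor values, shapes, and speaker ids below are placeholder assumptions, and the commented-out call stands in for a constructed GlowTTS instance whose setup is omitted; which key is actually consumed depends on config.use_speaker_embedding and config.use_d_vector_file, as checked in the updated branches above.

# Illustrative sketch only; placeholder tensors, not the project's test code.
import torch

# Dummy batch: token ids [B, T_text] and mel spectrograms [B, C_mel, T_mel]
# (shape convention assumed here for illustration).
text_input = torch.randint(0, 50, (2, 30))
text_lengths = torch.tensor([30, 25])
mel_input = torch.randn(2, 80, 120)
mel_lengths = torch.tensor([120, 100])

# After this change, aux_input carries both keys. train_step now forwards
# batch["speaker_ids"] alongside batch["d_vectors"]; the key that is unused
# for a given config can stay None.
aux_input = {"d_vectors": None, "speaker_ids": torch.tensor([0, 1])}

# model = GlowTTS(config)  # placeholder; model/config construction omitted
# outputs = model.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input=aux_input)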